Skip to content

Instantly share code, notes, and snippets.

@DannyWeitekamp
Created September 4, 2022 22:59
Show Gist options
  • Select an option

  • Save DannyWeitekamp/989e6f3a09f1e8cd51e42a2bebeb72d5 to your computer and use it in GitHub Desktop.

Select an option

Save DannyWeitekamp/989e6f3a09f1e8cd51e42a2bebeb72d5 to your computer and use it in GitHub Desktop.
Example of writing an iterator for the Interval() struct shown in the Numba docs
from numba import njit, f8
from numba.typed import List
from numba.extending import models, register_model
class Interval:
    """A half-open interval [lo, hi) on the real number line.

    Plain-Python counterpart of the Numba ``IntervalType``; the box/unbox
    hooks in this file convert between the two when crossing the jit boundary.
    """

    def __init__(self, lo, hi):
        self.lo, self.hi = lo, hi

    def __repr__(self):
        return 'Interval(%f, %f)' % (self.lo, self.hi)

    @property
    def width(self):
        """Length of the interval, computed as ``hi - lo``."""
        return self.hi - self.lo
from numba import types
class IntervalType(types.IterableType):
    """Numba type describing Python ``Interval`` objects.

    Declared iterable so that ``for v in interval`` can be compiled in
    nopython mode; iteration is described by ``IntervalIteratorType``.
    """

    def __init__(self):
        super().__init__(name='Interval')

    @property
    def iterator_type(self):
        # An iterator type is its own ``iterator_type`` in Numba, so the
        # freshly built IntervalIteratorType can be returned as-is.
        return IntervalIteratorType(self)
class IntervalIteratorType(types.SimpleIteratorType):
    """Numba type of the iterator obtained from an Interval.

    ``parent`` records the IntervalType being iterated; every yielded value
    is a float64 (``f8``).
    """

    def __init__(self, interval):
        self.parent = interval
        super().__init__(f"iter[{interval}]", f8)
@register_model(IntervalIteratorType)
class IntervalIterModel(models.StructModel):
    """Data model (native layout) of IntervalIteratorType.

    Fields:
      parent   -- the Interval struct being iterated.
      curr_val -- pointer to a float64 slot holding the next value to yield;
                  the iternext lowering loads from and stores to it.
    """

    def __init__(self, dmm, fe_type):
        super().__init__(dmm, fe_type, [
            ('parent', fe_type.parent),
            ('curr_val', types.EphemeralPointer(types.float64)),
        ])
from numba.extending import typeof_impl
# Single shared Numba type instance used for every Interval value.
interval_type = IntervalType()


@typeof_impl.register(Interval)
def typeof_index(val, c):
    """Tell Numba's typeof machinery that any Interval is ``interval_type``.

    Neither the value ``val`` nor the typeof context ``c`` is consulted.
    """
    return interval_type
# interval_iter_type = IntervalIteratorType()
from numba.extending import as_numba_type
from numba.extending import type_callable
@type_callable(Interval)
def type_interval(context):
    """Typing rule for calling ``Interval(lo, hi)`` in jitted code.

    The call is typed as ``interval_type`` only when both arguments are
    floats; otherwise the typer returns None, meaning "no match".
    """
    def typer(lo, hi):
        both_float = isinstance(lo, types.Float) and isinstance(hi, types.Float)
        return interval_type if both_float else None
    return typer


# Map the Python class to its Numba type for as_numba_type() lookups.
as_numba_type.register(Interval, interval_type)
from numba.extending import models, register_model
@register_model(IntervalType)
class IntervalModel(models.StructModel):
    """Data model of IntervalType: a struct of two float64 fields, lo and hi."""

    def __init__(self, dmm, fe_type):
        fields = [('lo', types.float64), ('hi', types.float64)]
        super().__init__(dmm, fe_type, fields)
from numba.extending import make_attribute_wrapper
# Make the IntervalModel struct members 'lo' and 'hi' accessible as
# ``interval.lo`` / ``interval.hi`` on IntervalType values in jitted code
# (inside_interval and the width getter rely on these).
make_attribute_wrapper(IntervalType, 'lo', 'lo')
make_attribute_wrapper(IntervalType, 'hi', 'hi')
from numba.extending import overload_attribute
@overload_attribute(IntervalType, "width")
def get_width(interval):
    """Provide the ``width`` attribute (hi - lo) for Intervals in jitted code."""
    def width_impl(interval):
        span = interval.hi - interval.lo
        return span
    return width_impl
from numba.extending import lower_builtin
from numba.core import cgutils
@lower_builtin(Interval, types.Float, types.Float)
def impl_interval(context, builder, sig, args):
    """Lower a nopython-mode ``Interval(lo, hi)`` constructor call.

    Allocates a struct of ``sig.return_type`` and fills its lo/hi fields
    from the two already-lowered float arguments.
    """
    make_struct = cgutils.create_struct_proxy(sig.return_type)
    proxy = make_struct(context, builder)
    proxy.lo, proxy.hi = args
    return proxy._getvalue()
from numba.extending import unbox, NativeValue
@unbox(IntervalType)
def unbox_interval(typ, obj, c):
    """
    Convert a Python Interval object into a native interval struct.

    Reads the ``lo`` and ``hi`` attributes as C doubles, releases the
    temporary attribute references, and flags an error if the Python
    error indicator is set afterwards.
    """
    pyapi = c.pyapi
    lo_obj = pyapi.object_getattr_string(obj, "lo")
    hi_obj = pyapi.object_getattr_string(obj, "hi")
    native = cgutils.create_struct_proxy(typ)(c.context, c.builder)
    native.lo = pyapi.float_as_double(lo_obj)
    native.hi = pyapi.float_as_double(hi_obj)
    for attr_obj in (lo_obj, hi_obj):
        pyapi.decref(attr_obj)
    failed = cgutils.is_not_null(c.builder, pyapi.err_occurred())
    return NativeValue(native._getvalue(), is_error=failed)
from numba.extending import box
@box(IntervalType)
def box_interval(typ, val, c):
    """
    Convert a native interval struct back into a Python Interval object.

    Creates Python floats for lo/hi, calls the Interval class with them, and
    releases the temporaries; returns the new object reference.
    """
    pyapi = c.pyapi
    native = cgutils.create_struct_proxy(typ)(c.context, c.builder, value=val)
    lo_obj = pyapi.float_from_double(native.lo)
    hi_obj = pyapi.float_from_double(native.hi)
    # The Interval class is serialized at compile time and rebuilt at run time
    # so the compiled code can call it.
    cls_obj = pyapi.unserialize(pyapi.serialize_object(Interval))
    boxed = pyapi.call_function_objargs(cls_obj, (lo_obj, hi_obj))
    for tmp in (lo_obj, hi_obj, cls_obj):
        pyapi.decref(tmp)
    return boxed
from numba import jit
@jit(nopython=True)
def inside_interval(interval, x):
    """Return True when ``x`` lies in the half-open range [lo, hi)."""
    lower_ok = interval.lo <= x
    return lower_ok and x < interval.hi
@jit(nopython=True)
def interval_width(interval):
    """Return ``interval.width`` computed inside nopython code."""
    w = interval.width
    return w
@jit(nopython=True)
def sum_intervals(i, j):
    """Return a new Interval whose bounds are the element-wise sums of i and j."""
    new_lo = i.lo + j.lo
    new_hi = i.hi + j.hi
    return Interval(new_lo, new_hi)
# Smoke tests for the jitted Interval helpers: membership at both sides of
# the upper bound, attribute access, and boxing a newly built Interval.
# inside_interval returns a bool, so assert on it directly (PEP 8 / E712)
# instead of comparing with == True / == False.
assert inside_interval(Interval(1.0, 5.0), 4)
assert not inside_interval(Interval(1.0, 5.0), 6)
print(interval_width(Interval(1.0, 6.0)))
print(sum_intervals(Interval(1.0, 6.0), Interval(1.0, 6.0)))
###########
## ^ Everything above is taken from https://numba.pydata.org/numba-doc/latest/extending/interval-example.html
## v Below: a test of implementing "getiter" (iteration over an Interval)
###########
from numba import f8
from numba.core.imputils import lower_builtin, iternext_impl, RefType, impl_ret_borrowed
@lower_builtin("getiter", interval_type)
def iterval_getiter(context, builder, sig, args):
    """Lower ``iter(interval)``: build an iterator that starts at ``lo``.

    The iterator keeps the source interval in ``parent`` and a stack slot,
    initialised to ``lo``, that holds the next value to yield.
    """
    interval_val = args[0]
    src = cgutils.create_struct_proxy(sig.args[0])(context, builder, value=interval_val)
    it = cgutils.create_struct_proxy(sig.return_type)(context, builder)
    it.parent = interval_val
    it.curr_val = cgutils.alloca_once_value(builder, src.lo)
    return it._getvalue()
@lower_builtin('iternext', IntervalIteratorType)
@iternext_impl(RefType.BORROWED)
def iternext_listiter(context, builder, sig, args, result):
    """Lower one ``next()`` step of an Interval iterator.

    Yields the current value while it is below ``hi`` and then advances the
    stored value by 1.0; once the value reaches ``hi`` the iterator is
    exhausted.
    """
    it = cgutils.create_struct_proxy(sig.args[0])(context, builder, value=args[0])
    bounds = cgutils.create_struct_proxy(interval_type)(context, builder, value=it.parent)
    # Read the next candidate value and test it against the upper bound.
    # An ordered compare is false for NaN, which also ends the iteration.
    value = builder.load(it.curr_val)
    in_range = builder.fcmp_ordered('<', value, bounds.hi)
    result.set_valid(in_range)
    with builder.if_then(in_range):
        result.yield_(value)
        one = context.get_constant(types.float64, 1.0)
        builder.store(builder.fadd(value, one), it.curr_val)
@jit(nopython=True)
def iter_interval(lo, hi):
    """Print each value produced by iterating ``Interval(lo, hi)``."""
    interval = Interval(f8(lo), f8(hi))
    for v in interval:
        print("<<", v)


iter_interval(1, 10)
#-------------------------------------------
# Speed tests
import numpy as np
@jit(nopython=True)
def accum_numpy(lo, hi):
    """Sum lo, lo+1, ..., hi-1 as float64 using a NumPy array."""
    values = np.arange(lo, hi, dtype=np.float64)
    return np.sum(values)
@jit(nopython=True)
def accum_interval(lo, hi):
    """Sum the values yielded by iterating ``Interval(lo, hi)``."""
    total = 0
    for value in Interval(f8(lo), f8(hi)):
        total += value
    return total
@jit(nopython=True)
def accum_loop(lo, hi):
    """Sum the integers lo, lo+1, ..., hi-1 with a plain ``range`` loop.

    Baseline for the Interval-iterator benchmark. The accumulator is float64
    (``f8``) so the result is comparable to accum_interval and accum_numpy.
    """
    a = f8(0.0)
    # range() rebinds v on every iteration, so no separate initialisation or
    # manual increment of v is needed (either would be a dead store).
    for v in range(lo, hi):
        a += v
    return a
import time
class PrintElapse:
    """Context manager that prints how long its ``with`` body took.

    On exit it prints ``"<name>: <elapsed> ms"`` with two decimals. The start
    and end readings (in milliseconds) stay available as ``t0`` and ``t1``.
    Exceptions raised in the body are not suppressed.
    """

    def __init__(self, name):
        self.name = name

    def __enter__(self):
        # perf_counter_ns is a monotonic, high-resolution clock meant for
        # measuring intervals; time.time_ns is wall-clock time that can be
        # adjusted (e.g. by NTP) mid-measurement.
        self.t0 = time.perf_counter_ns() / 1e6
        # Return self so ``with PrintElapse(...) as t`` gives access to t0/t1.
        return self

    def __exit__(self, *args):
        self.t1 = time.perf_counter_ns() / 1e6
        print(f'{self.name}: {self.t1-self.t0:.2f} ms')
N = 1_000_000
res = np.sum(np.arange(N))

# For each implementation, the asserted call checks the result and (because
# @jit compiles lazily on first use) also warms up compilation, so the timed
# call that follows measures execution only.
for _label, _accum in (
    ("accum_numpy", accum_numpy),
    ("accum_interval", accum_interval),
    ("accum_loop", accum_loop),
):
    assert _accum(0, N) == res
    with PrintElapse(_label):
        _accum(0, N)
del _label, _accum
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment