Created
September 4, 2022 22:59
-
-
Save DannyWeitekamp/989e6f3a09f1e8cd51e42a2bebeb72d5 to your computer and use it in GitHub Desktop.
Revisions
-
DannyWeitekamp created this gist
Sep 4, 2022 .There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters. Learn more about bidirectional Unicode charactersOriginal file line number Diff line number Diff line change @@ -0,0 +1,265 @@ from numba import njit, f8 from numba.typed import List from numba.extending import models, register_model class Interval(object): """ A half-open interval on the real number line. """ def __init__(self, lo, hi): self.lo = lo self.hi = hi def __repr__(self): return 'Interval(%f, %f)' % (self.lo, self.hi) @property def width(self): return self.hi - self.lo from numba import types class IntervalType(types.IterableType): def __init__(self): super(IntervalType, self).__init__(name='Interval') @property def iterator_type(self): return IntervalIteratorType(self).iterator_type class IntervalIteratorType(types.SimpleIteratorType): def __init__(self, interval): self.parent = interval # self.iterable = iterable # yield_type = iterable.yield_type name = f"iter[{interval}]" super().__init__(name, f8) @register_model(IntervalIteratorType) class IntervalIterModel(models.StructModel): def __init__(self, dmm, fe_type): members = [ ('parent', fe_type.parent), # ('curr_val', types.float64) ('curr_val', types.EphemeralPointer(types.float64)) ] super().__init__(dmm, fe_type, members) from numba.extending import typeof_impl @typeof_impl.register(Interval) def typeof_index(val, c): return interval_type interval_type = IntervalType() # interval_iter_type = IntervalIteratorType() from numba.extending import as_numba_type from numba.extending import type_callable @type_callable(Interval) def type_interval(context): def typer(lo, hi): if isinstance(lo, types.Float) and isinstance(hi, types.Float): return interval_type return typer as_numba_type.register(Interval, interval_type) from numba.extending import models, register_model @register_model(IntervalType) class IntervalModel(models.StructModel): def __init__(self, dmm, fe_type): members = [ ('lo', types.float64), ('hi', types.float64), ] models.StructModel.__init__(self, dmm, fe_type, members) from numba.extending import make_attribute_wrapper make_attribute_wrapper(IntervalType, 'lo', 'lo') make_attribute_wrapper(IntervalType, 'hi', 'hi') from numba.extending import overload_attribute @overload_attribute(IntervalType, "width") def get_width(interval): def getter(interval): return interval.hi - interval.lo return getter from numba.extending import lower_builtin from numba.core import cgutils @lower_builtin(Interval, types.Float, types.Float) def impl_interval(context, builder, sig, args): typ = sig.return_type lo, hi = args interval = cgutils.create_struct_proxy(typ)(context, builder) interval.lo = lo interval.hi = hi return interval._getvalue() from numba.extending import unbox, NativeValue @unbox(IntervalType) def unbox_interval(typ, obj, c): """ Convert a Interval object to a native interval structure. """ lo_obj = c.pyapi.object_getattr_string(obj, "lo") hi_obj = c.pyapi.object_getattr_string(obj, "hi") interval = cgutils.create_struct_proxy(typ)(c.context, c.builder) interval.lo = c.pyapi.float_as_double(lo_obj) interval.hi = c.pyapi.float_as_double(hi_obj) c.pyapi.decref(lo_obj) c.pyapi.decref(hi_obj) is_error = cgutils.is_not_null(c.builder, c.pyapi.err_occurred()) return NativeValue(interval._getvalue(), is_error=is_error) from numba.extending import box @box(IntervalType) def box_interval(typ, val, c): """ Convert a native interval structure to an Interval object. """ interval = cgutils.create_struct_proxy(typ)(c.context, c.builder, value=val) lo_obj = c.pyapi.float_from_double(interval.lo) hi_obj = c.pyapi.float_from_double(interval.hi) class_obj = c.pyapi.unserialize(c.pyapi.serialize_object(Interval)) res = c.pyapi.call_function_objargs(class_obj, (lo_obj, hi_obj)) c.pyapi.decref(lo_obj) c.pyapi.decref(hi_obj) c.pyapi.decref(class_obj) return res from numba import jit @jit(nopython=True) def inside_interval(interval, x): return interval.lo <= x < interval.hi @jit(nopython=True) def interval_width(interval): return interval.width @jit(nopython=True) def sum_intervals(i, j): return Interval(i.lo + j.lo, i.hi + j.hi) assert inside_interval(Interval(1.0,5.0),4) == True assert inside_interval(Interval(1.0,5.0),6) == False print(interval_width(Interval(1.0,6.0))) print(sum_intervals(Interval(1.0,6.0),Interval(1.0,6.0))) ########### ## ^ Above all from https://numba.pydata.org/numba-doc/latest/extending/interval-example.html ## v Below a test of implementing "getiter" ########### from numba import f8 from numba.core.imputils import lower_builtin, iternext_impl, RefType, impl_ret_borrowed @lower_builtin("getiter", interval_type) def iterval_getiter(context, builder, sig, args): interval = cgutils.create_struct_proxy(sig.args[0])(context, builder, value=args[0]) _iter = cgutils.create_struct_proxy(sig.return_type)(context, builder) _iter.parent = args[0] _iter.curr_val = cgutils.alloca_once_value(builder, interval.lo) return _iter._getvalue() @lower_builtin('iternext', IntervalIteratorType) @iternext_impl(RefType.BORROWED) def iternext_listiter(context, builder, sig, args, result): # Make iter and interval proxies _iter = cgutils.create_struct_proxy(sig.args[0])(context, builder, value=args[0]) interval = cgutils.create_struct_proxy(interval_type)(context, builder, value=_iter.parent) # load value and curr_val = builder.load(_iter.curr_val) is_valid = builder.fcmp_ordered('<', curr_val, interval.hi) result.set_valid(is_valid) with builder.if_then(is_valid): result.yield_(curr_val) new_val = builder.fadd(curr_val, context.get_constant(types.float64, 1.0)) builder.store(new_val, _iter.curr_val) @jit(nopython=True) def iter_interval(lo, hi): for v in Interval(f8(lo),f8(hi)): print("<<", v) iter_interval(1,10) #------------------------------------------- # : Speed Tests import numpy as np @jit(nopython=True) def accum_numpy(lo, hi): return np.sum(np.arange(lo,hi,dtype=np.float64)) @jit(nopython=True) def accum_interval(lo, hi): a = 0 for v in Interval(f8(lo),f8(hi)): a += v return a @jit(nopython=True) def accum_loop(lo, hi): a = f8(0.0) v = f8(0.0) for v in range(lo, hi): a += v v += 1.0 return a import time class PrintElapse(): def __init__(self, name): self.name = name def __enter__(self): self.t0 = time.time_ns()/float(1e6) def __exit__(self,*args): self.t1 = time.time_ns()/float(1e6) print(f'{self.name}: {self.t1-self.t0:.2f} ms') N = 1000000 res = np.sum(np.arange(N)) assert accum_numpy(0,N) == res with PrintElapse("accum_numpy"): accum_numpy(0,N) assert accum_interval(0,N) == res with PrintElapse("accum_interval"): accum_interval(0,N) assert accum_loop(0,N) == res with PrintElapse("accum_loop"): accum_loop(0,N)