Created
September 4, 2022 22:59
-
-
Save DannyWeitekamp/989e6f3a09f1e8cd51e42a2bebeb72d5 to your computer and use it in GitHub Desktop.
Example of writing an iterator for the `Interval` struct shown in the Numba extension docs
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| from numba import njit, f8 | |
| from numba.typed import List | |
| from numba.extending import models, register_model | |
class Interval:
    """
    A half-open interval ``[lo, hi)`` on the real number line.

    Plain-Python value type; ``unbox_interval``/``box_interval`` in this file
    convert it to and from the native ``IntervalType`` struct.
    """

    def __init__(self, lo, hi):
        self.lo, self.hi = lo, hi

    def __repr__(self):
        endpoints = (self.lo, self.hi)
        return 'Interval(%f, %f)' % endpoints

    @property
    def width(self):
        """Length of the interval, ``hi - lo``."""
        return self.hi - self.lo
| from numba import types | |
class IntervalType(types.IterableType):
    """Numba type of ``Interval`` values; iterable via ``IntervalIteratorType``."""

    def __init__(self):
        super().__init__(name='Interval')

    @property
    def iterator_type(self):
        # Numba asks an iterable type for the type produced by iter(); build the
        # iterator type for this interval and return its iterator_type.
        iterator = IntervalIteratorType(self)
        return iterator.iterator_type
class IntervalIteratorType(types.SimpleIteratorType):
    """Numba type of the iterator produced by ``iter(interval)``.

    Yields ``f8`` (float64) values. ``parent`` keeps the iterated
    ``IntervalType`` so the iterator's data model can embed it.
    """

    def __init__(self, interval):
        self.parent = interval
        super().__init__(f"iter[{interval}]", f8)
@register_model(IntervalIteratorType)
class IntervalIterModel(models.StructModel):
    """Native layout of an interval iterator.

    ``parent`` holds the iterated interval struct. ``curr_val`` is an
    EphemeralPointer to a float64 slot (created with ``alloca_once_value`` in
    the getiter lowering) that holds the next value to yield, so iternext can
    advance it in place.
    """

    def __init__(self, dmm, fe_type):
        parent_member = ('parent', fe_type.parent)
        cursor_member = ('curr_val', types.EphemeralPointer(types.float64))
        super().__init__(dmm, fe_type, [parent_member, cursor_member])
| from numba.extending import typeof_impl | |
# Single numba type instance used for every Interval value.
interval_type = IntervalType()


@typeof_impl.register(Interval)
def typeof_index(val, c):
    """Report ``interval_type`` as the numba type of any ``Interval`` object."""
    return interval_type
| from numba.extending import as_numba_type | |
| from numba.extending import type_callable | |
@type_callable(Interval)
def type_interval(context):
    """Typing for calling ``Interval(lo, hi)`` inside jitted code."""

    def typer(lo, hi):
        # Only (Float, Float) arguments are supported; anything else yields
        # None, i.e. no matching signature.
        if not (isinstance(lo, types.Float) and isinstance(hi, types.Float)):
            return None
        return interval_type

    return typer


# Map the Python class Interval to its numba type in as_numba_type.
as_numba_type.register(Interval, interval_type)
| from numba.extending import models, register_model | |
@register_model(IntervalType)
class IntervalModel(models.StructModel):
    """Native layout of an Interval: a struct of two float64 fields, lo then hi."""

    def __init__(self, dmm, fe_type):
        fields = [(field_name, types.float64) for field_name in ('lo', 'hi')]
        super().__init__(dmm, fe_type, fields)
| from numba.extending import make_attribute_wrapper | |
# Expose the native struct fields as `.lo` / `.hi` attributes in jitted code.
for _field in ('lo', 'hi'):
    make_attribute_wrapper(IntervalType, _field, _field)
del _field
| from numba.extending import overload_attribute | |
@overload_attribute(IntervalType, "width")
def get_width(interval):
    """Provide ``interval.width`` (``hi - lo``) inside jitted code."""

    def width_impl(interval):
        return interval.hi - interval.lo

    return width_impl
| from numba.extending import lower_builtin | |
| from numba.core import cgutils | |
@lower_builtin(Interval, types.Float, types.Float)
def impl_interval(context, builder, sig, args):
    """Lower ``Interval(lo, hi)`` to a native struct filled from the two floats."""
    make_struct = cgutils.create_struct_proxy(sig.return_type)
    interval = make_struct(context, builder)
    interval.lo, interval.hi = args
    return interval._getvalue()
| from numba.extending import unbox, NativeValue | |
@unbox(IntervalType)
def unbox_interval(typ, obj, c):
    """
    Convert an Interval object to a native interval structure.

    Reads the ``lo`` and ``hi`` attributes of the Python object ``obj``,
    converts each to a C double and stores them in a new native struct.
    The returned NativeValue is flagged as an error when a Python exception
    is pending after the conversions (for instance from a failed attribute
    lookup or float conversion).
    """
    # Attribute objects obtained here are released with decref below.
    lo_obj = c.pyapi.object_getattr_string(obj, "lo")
    hi_obj = c.pyapi.object_getattr_string(obj, "hi")
    interval = cgutils.create_struct_proxy(typ)(c.context, c.builder)
    interval.lo = c.pyapi.float_as_double(lo_obj)
    interval.hi = c.pyapi.float_as_double(hi_obj)
    c.pyapi.decref(lo_obj)
    c.pyapi.decref(hi_obj)
    # Failures above are not checked individually; one pending-error check
    # here covers all of them.
    is_error = cgutils.is_not_null(c.builder, c.pyapi.err_occurred())
    return NativeValue(interval._getvalue(), is_error=is_error)
| from numba.extending import box | |
@box(IntervalType)
def box_interval(typ, val, c):
    """
    Convert a native interval structure to an Interval object.

    Creates Python floats from the native ``lo``/``hi`` fields and calls the
    ``Interval`` class with them; the result of that call is returned.
    """
    interval = cgutils.create_struct_proxy(typ)(c.context, c.builder, value=val)
    lo_obj = c.pyapi.float_from_double(interval.lo)
    hi_obj = c.pyapi.float_from_double(interval.hi)
    # The Interval class is serialized into the compiled code and unserialized
    # at runtime to obtain a callable class object.
    class_obj = c.pyapi.unserialize(c.pyapi.serialize_object(Interval))
    res = c.pyapi.call_function_objargs(class_obj, (lo_obj, hi_obj))
    # Release the temporaries; only `res` is handed back to the caller.
    c.pyapi.decref(lo_obj)
    c.pyapi.decref(hi_obj)
    c.pyapi.decref(class_obj)
    return res
| from numba import jit | |
@jit(nopython=True)
def inside_interval(interval, x):
    """Return True if ``x`` lies in the half-open range ``[lo, hi)`` of ``interval``."""
    return interval.lo <= x and x < interval.hi
@jit(nopython=True)
def interval_width(interval):
    """Return ``interval.width`` via the jitted ``width`` attribute overload."""
    width = interval.width
    return width
@jit(nopython=True)
def sum_intervals(i, j):
    """Return the endpoint-wise sum ``Interval(i.lo + j.lo, i.hi + j.hi)``."""
    new_lo = i.lo + j.lo
    new_hi = i.hi + j.hi
    return Interval(new_lo, new_hi)
# Smoke checks for the documentation example.
assert inside_interval(Interval(1.0, 5.0), 4)
assert not inside_interval(Interval(1.0, 5.0), 6)
print(interval_width(Interval(1.0, 6.0)))
print(sum_intervals(Interval(1.0, 6.0), Interval(1.0, 6.0)))
| ########### | |
| ## ^ Above all from https://numba.pydata.org/numba-doc/latest/extending/interval-example.html | |
| ## v Below a test of implementing "getiter" | |
| ########### | |
| from numba import f8 | |
| from numba.core.imputils import lower_builtin, iternext_impl, RefType, impl_ret_borrowed | |
@lower_builtin("getiter", interval_type)
def iterval_getiter(context, builder, sig, args):
    """Lower ``iter(interval)``: build an iterator whose cursor starts at ``lo``."""
    interval_val = args[0]
    source = cgutils.create_struct_proxy(sig.args[0])(context, builder, value=interval_val)
    iterator = cgutils.create_struct_proxy(sig.return_type)(context, builder)
    iterator.parent = interval_val
    # Slot holding the next value to yield; iternext advances it in place.
    iterator.curr_val = cgutils.alloca_once_value(builder, source.lo)
    return iterator._getvalue()
@lower_builtin('iternext', IntervalIteratorType)
@iternext_impl(RefType.BORROWED)
def iternext_listiter(context, builder, sig, args, result):
    """Lower one ``next()`` step of an interval iterator.

    While the stored value is below the interval's ``hi`` it is yielded and
    then advanced by 1.0; otherwise the iterator reports exhaustion.
    """
    iterator = cgutils.create_struct_proxy(sig.args[0])(context, builder, value=args[0])
    bounds = cgutils.create_struct_proxy(interval_type)(context, builder, value=iterator.parent)

    current = builder.load(iterator.curr_val)
    # Ordered comparison: false if either operand is NaN, ending the loop.
    has_next = builder.fcmp_ordered('<', current, bounds.hi)
    result.set_valid(has_next)

    with builder.if_then(has_next):
        result.yield_(current)
        step = context.get_constant(types.float64, 1.0)
        builder.store(builder.fadd(current, step), iterator.curr_val)
@jit(nopython=True)
def iter_interval(lo, hi):
    """Print every value produced by iterating ``Interval(lo, hi)``."""
    interval = Interval(f8(lo), f8(hi))
    for value in interval:
        print("<<", value)


iter_interval(1, 10)
| #------------------------------------------- | |
| # : Speed Tests | |
| import numpy as np | |
@jit(nopython=True)
def accum_numpy(lo, hi):
    """Benchmark baseline: sum ``lo, lo+1, ..., hi-1`` as float64 via ``np.arange``."""
    values = np.arange(lo, hi, dtype=np.float64)
    return np.sum(values)
@jit(nopython=True)
def accum_interval(lo, hi):
    """Benchmark: sum the values yielded by iterating ``Interval(lo, hi)``."""
    # float64 accumulator (numba already unified the original int 0 to float64).
    total = 0.0
    for value in Interval(f8(lo), f8(hi)):
        total += value
    return total
@jit(nopython=True)
def accum_loop(lo, hi):
    """Benchmark baseline: sum ``lo, lo+1, ..., hi-1`` with a plain ``range`` loop.

    The accumulator is float64, so the result type matches the other
    ``accum_*`` functions.
    """
    a = f8(0.0)
    # The integer loop counter is promoted to float64 on each addition. The
    # previous version pre-initialised `v` and incremented it inside the loop;
    # both were dead stores (range rebinds `v` every iteration).
    for v in range(lo, hi):
        a += v
    return a
| import time | |
class PrintElapse:
    """Context manager that prints the elapsed time of its ``with`` block.

    On exit it prints ``"<name>: <elapsed> ms"`` with two decimals. Timing
    uses the monotonic, high-resolution ``time.perf_counter_ns`` clock, so
    measurements are not affected by system clock adjustments (unlike
    ``time.time_ns``). ``t0``/``t1`` hold the start/end readings in
    milliseconds relative to an arbitrary reference point; only their
    difference is meaningful.
    """

    def __init__(self, name):
        self.name = name

    def __enter__(self):
        self.t0 = time.perf_counter_ns() / 1e6
        # Return self so `with PrintElapse(...) as timer:` gives access to t0/t1.
        return self

    def __exit__(self, *args):
        self.t1 = time.perf_counter_ns() / 1e6
        print(f'{self.name}: {self.t1 - self.t0:.2f} ms')
        # Implicitly returns None, so exceptions raised in the block propagate.
N = 1_000_000
res = np.sum(np.arange(N))

# Each function is first called outside the timer (checking its result, which
# also compiles it), then timed on a second call.
for _name, _fn in (
    ("accum_numpy", accum_numpy),
    ("accum_interval", accum_interval),
    ("accum_loop", accum_loop),
):
    assert _fn(0, N) == res
    with PrintElapse(_name):
        _fn(0, N)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment