Skip to content

Instantly share code, notes, and snippets.

@DannyWeitekamp
Created September 4, 2022 22:59
Show Gist options
  • Select an option

  • Save DannyWeitekamp/989e6f3a09f1e8cd51e42a2bebeb72d5 to your computer and use it in GitHub Desktop.

Select an option

Save DannyWeitekamp/989e6f3a09f1e8cd51e42a2bebeb72d5 to your computer and use it in GitHub Desktop.

Revisions

  1. DannyWeitekamp created this gist Sep 4, 2022.
    265 changes: 265 additions & 0 deletions Numba_Iterator.py
    Original file line number Diff line number Diff line change
    @@ -0,0 +1,265 @@
    from numba import njit, f8
    from numba.typed import List
    from numba.extending import models, register_model

    class Interval(object):
    """
    A half-open interval on the real number line.
    """
    def __init__(self, lo, hi):
    self.lo = lo
    self.hi = hi

    def __repr__(self):
    return 'Interval(%f, %f)' % (self.lo, self.hi)

    @property
    def width(self):
    return self.hi - self.lo


    from numba import types


    class IntervalType(types.IterableType):
    def __init__(self):
    super(IntervalType, self).__init__(name='Interval')

    @property
    def iterator_type(self):
    return IntervalIteratorType(self).iterator_type


    class IntervalIteratorType(types.SimpleIteratorType):
    def __init__(self, interval):
    self.parent = interval
    # self.iterable = iterable
    # yield_type = iterable.yield_type
    name = f"iter[{interval}]"
    super().__init__(name, f8)


    @register_model(IntervalIteratorType)
    class IntervalIterModel(models.StructModel):
    def __init__(self, dmm, fe_type):
    members = [
    ('parent', fe_type.parent),
    # ('curr_val', types.float64)
    ('curr_val', types.EphemeralPointer(types.float64))
    ]
    super().__init__(dmm, fe_type, members)


    from numba.extending import typeof_impl

    @typeof_impl.register(Interval)
    def typeof_index(val, c):
    return interval_type

    interval_type = IntervalType()
    # interval_iter_type = IntervalIteratorType()


    from numba.extending import as_numba_type
    from numba.extending import type_callable

    @type_callable(Interval)
    def type_interval(context):
    def typer(lo, hi):
    if isinstance(lo, types.Float) and isinstance(hi, types.Float):
    return interval_type
    return typer

    as_numba_type.register(Interval, interval_type)


    from numba.extending import models, register_model

    @register_model(IntervalType)
    class IntervalModel(models.StructModel):
    def __init__(self, dmm, fe_type):
    members = [
    ('lo', types.float64),
    ('hi', types.float64),
    ]
    models.StructModel.__init__(self, dmm, fe_type, members)

    from numba.extending import make_attribute_wrapper

    make_attribute_wrapper(IntervalType, 'lo', 'lo')
    make_attribute_wrapper(IntervalType, 'hi', 'hi')


    from numba.extending import overload_attribute

    @overload_attribute(IntervalType, "width")
    def get_width(interval):
    def getter(interval):
    return interval.hi - interval.lo
    return getter


    from numba.extending import lower_builtin
    from numba.core import cgutils

    @lower_builtin(Interval, types.Float, types.Float)
    def impl_interval(context, builder, sig, args):
    typ = sig.return_type
    lo, hi = args
    interval = cgutils.create_struct_proxy(typ)(context, builder)
    interval.lo = lo
    interval.hi = hi
    return interval._getvalue()

    from numba.extending import unbox, NativeValue

    @unbox(IntervalType)
    def unbox_interval(typ, obj, c):
    """
    Convert a Interval object to a native interval structure.
    """
    lo_obj = c.pyapi.object_getattr_string(obj, "lo")
    hi_obj = c.pyapi.object_getattr_string(obj, "hi")
    interval = cgutils.create_struct_proxy(typ)(c.context, c.builder)
    interval.lo = c.pyapi.float_as_double(lo_obj)
    interval.hi = c.pyapi.float_as_double(hi_obj)
    c.pyapi.decref(lo_obj)
    c.pyapi.decref(hi_obj)
    is_error = cgutils.is_not_null(c.builder, c.pyapi.err_occurred())
    return NativeValue(interval._getvalue(), is_error=is_error)


    from numba.extending import box

    @box(IntervalType)
    def box_interval(typ, val, c):
    """
    Convert a native interval structure to an Interval object.
    """
    interval = cgutils.create_struct_proxy(typ)(c.context, c.builder, value=val)
    lo_obj = c.pyapi.float_from_double(interval.lo)
    hi_obj = c.pyapi.float_from_double(interval.hi)
    class_obj = c.pyapi.unserialize(c.pyapi.serialize_object(Interval))
    res = c.pyapi.call_function_objargs(class_obj, (lo_obj, hi_obj))
    c.pyapi.decref(lo_obj)
    c.pyapi.decref(hi_obj)
    c.pyapi.decref(class_obj)
    return res


    from numba import jit

    @jit(nopython=True)
    def inside_interval(interval, x):
    return interval.lo <= x < interval.hi

    @jit(nopython=True)
    def interval_width(interval):
    return interval.width

    @jit(nopython=True)
    def sum_intervals(i, j):
    return Interval(i.lo + j.lo, i.hi + j.hi)


    assert inside_interval(Interval(1.0,5.0),4) == True
    assert inside_interval(Interval(1.0,5.0),6) == False

    print(interval_width(Interval(1.0,6.0)))
    print(sum_intervals(Interval(1.0,6.0),Interval(1.0,6.0)))

    ###########
    ## ^ Above all from https://numba.pydata.org/numba-doc/latest/extending/interval-example.html
    ## v Below a test of implementing "getiter"
    ###########

    from numba import f8
    from numba.core.imputils import lower_builtin, iternext_impl, RefType, impl_ret_borrowed
    @lower_builtin("getiter", interval_type)
    def iterval_getiter(context, builder, sig, args):
    interval = cgutils.create_struct_proxy(sig.args[0])(context, builder, value=args[0])
    _iter = cgutils.create_struct_proxy(sig.return_type)(context, builder)
    _iter.parent = args[0]
    _iter.curr_val = cgutils.alloca_once_value(builder, interval.lo)
    return _iter._getvalue()


    @lower_builtin('iternext', IntervalIteratorType)
    @iternext_impl(RefType.BORROWED)
    def iternext_listiter(context, builder, sig, args, result):
    # Make iter and interval proxies
    _iter = cgutils.create_struct_proxy(sig.args[0])(context, builder, value=args[0])
    interval = cgutils.create_struct_proxy(interval_type)(context, builder, value=_iter.parent)

    # load value and
    curr_val = builder.load(_iter.curr_val)
    is_valid = builder.fcmp_ordered('<', curr_val, interval.hi)
    result.set_valid(is_valid)

    with builder.if_then(is_valid):
    result.yield_(curr_val)
    new_val = builder.fadd(curr_val, context.get_constant(types.float64, 1.0))
    builder.store(new_val, _iter.curr_val)


    @jit(nopython=True)
    def iter_interval(lo, hi):
    for v in Interval(f8(lo),f8(hi)):
    print("<<", v)

    iter_interval(1,10)


    #-------------------------------------------
    # : Speed Tests


    import numpy as np
    @jit(nopython=True)
    def accum_numpy(lo, hi):
    return np.sum(np.arange(lo,hi,dtype=np.float64))

    @jit(nopython=True)
    def accum_interval(lo, hi):
    a = 0
    for v in Interval(f8(lo),f8(hi)):
    a += v
    return a

    @jit(nopython=True)
    def accum_loop(lo, hi):
    a = f8(0.0)
    v = f8(0.0)
    for v in range(lo, hi):
    a += v
    v += 1.0
    return a

    import time
    class PrintElapse():
    def __init__(self, name):
    self.name = name
    def __enter__(self):
    self.t0 = time.time_ns()/float(1e6)
    def __exit__(self,*args):
    self.t1 = time.time_ns()/float(1e6)
    print(f'{self.name}: {self.t1-self.t0:.2f} ms')

    N = 1000000
    res = np.sum(np.arange(N))

    assert accum_numpy(0,N) == res
    with PrintElapse("accum_numpy"):
    accum_numpy(0,N)

    assert accum_interval(0,N) == res
    with PrintElapse("accum_interval"):
    accum_interval(0,N)

    assert accum_loop(0,N) == res
    with PrintElapse("accum_loop"):
    accum_loop(0,N)