Skip to content

Instantly share code, notes, and snippets.

@DannyWeitekamp
Last active November 17, 2022 19:26
Show Gist options
  • Select an option

  • Save DannyWeitekamp/612b4cdabc0fa1ee469587e4c2b4772b to your computer and use it in GitHub Desktop.

Select an option

Save DannyWeitekamp/612b4cdabc0fa1ee469587e4c2b4772b to your computer and use it in GitHub Desktop.
Examples of how to reduce the overhead of calling Numba-compiled functions from Python
from numba import njit, cfunc
from numba.types import unicode_type, i8
import time
class PrintElapse:
    """Context manager that prints how long its ``with`` body took.

    On exit it prints ``"<name>: <elapsed> ms"`` with two decimal places.

    Attributes:
        name: Label printed before the elapsed time.
        t0: Start reading in milliseconds (arbitrary origin; only
            differences are meaningful). Set by ``__enter__``.
        t1: End reading in milliseconds, same origin as ``t0``. Set by
            ``__exit__``.
        elapsed_ms: Duration of the ``with`` body in milliseconds. Set by
            ``__exit__``.
    """

    def __init__(self, name):
        self.name = name

    def __enter__(self):
        # perf_counter_ns is monotonic and high-resolution. time.time_ns is
        # wall-clock time and can jump (e.g. NTP adjustments), which corrupts
        # sub-millisecond benchmark measurements.
        self._t0_ns = time.perf_counter_ns()
        self.t0 = self._t0_ns / 1e6
        # Return self so callers can read elapsed_ms after the block.
        return self

    def __exit__(self, *args):
        t1_ns = time.perf_counter_ns()
        self.t1 = t1_ns / 1e6
        # Subtract the integer nanosecond readings before converting, to avoid
        # losing precision on large absolute counter values.
        self.elapsed_ms = (t1_ns - self._t0_ns) / 1e6
        print(f'{self.name}: {self.elapsed_ms:.2f} ms')
        # Implicitly returns None, so exceptions raised in the body propagate.
@njit(cache=True)
def foo(x, a0, a1, a2):
    """Return ``len(x)``; the arguments ``a0``, ``a1`` and ``a2`` are ignored.

    No signature is given, so numba compiles this lazily for the argument
    types seen at the first call. The unused arguments are presumably there
    so the benchmark includes the cost of passing several values through
    the dispatcher.
    """
    n = len(x)
    return n
@cfunc(i8(unicode_type, i8, i8, i8), nopython=True, cache=True)
def bar(x, a0, a1, a2):
    """C callback returning ``len(x)``; ``a0``, ``a1`` and ``a2`` are unused.

    The signature is fixed at decoration time:
    (unicode_type, i8, i8, i8) -> i8.

    NOTE(review): calling ``bar(...)`` directly from Python appears to run
    the original Python function (``CFunc.__call__`` forwards to it) rather
    than the compiled callback. The module-level ``bar("HELLO WORLD", "MOOP",
    1, 2)`` call succeeding despite the ``i8`` slot is consistent with this.
    Confirm for the numba version in use.
    """
    n = len(x)
    return n
@njit(i8(unicode_type, i8, i8, i8), cache=True)
def baz(x, a0, a1, a2):
    """Return ``len(x)``, compiled eagerly for one explicit signature.

    The signature (unicode_type, i8, i8, i8) -> i8 is compiled at decoration
    time, so ``baz.overloads`` holds exactly that entry and its
    ``entry_point`` can be fetched at module level. ``a0``, ``a1`` and
    ``a2`` are unused.

    ``nopython=True`` is deliberately omitted: ``njit`` already forces
    nopython mode and emits a RuntimeWarning if the keyword is passed.
    """
    return len(x)
# Fetch the compiled wrapper for baz's single declared overload. Calling it
# directly presumably bypasses the dispatcher's per-call type matching, which
# is what the "entry_point" row below measures.
baz_entry_point = baz.overloads[(unicode_type, i8, i8, i8)].entry_point

# Sanity / warm-up calls. The first foo(...) call compiles (or loads from
# cache) the lazily-typed njit function, keeping that cost out of the timed
# loops below.
print(foo("HELLO WORLD", 0, 1, 2))
# This does not raise even though bar's signature declares i8 for a0.
# Calling a CFunc object from Python runs the original Python function,
# so no compiled code and no type checking are involved here.
print(bar("HELLO WORLD", "MOOP", 1, 2))
print(bar("HELLO WORLD", 0, 1, 2))
print(baz_entry_point("HELLO WORLD", 0, 1, 2))

N_CALLS = 1000

# (label, callable) pairs, each timed over N_CALLS calls with identical
# arguments.
_BENCHMARKS = (
    ("njit", foo),
    ("python", foo.py_func),
    # Calling `bar` from Python forwards to the pure-Python function, so this
    # row measures py_func plus CFunc.__call__ overhead, NOT the compiled C
    # callback. The compiled code is only reachable via bar.address /
    # bar.ctypes / bar.cffi, which are not straightforward to call with a
    # unicode_type argument.
    ("cfunc (via __call__ -> py_func)", bar),
    ("entry_point", baz_entry_point),
)

for label, fn in _BENCHMARKS:
    with PrintElapse(label):
        for i in range(N_CALLS):
            fn("HELLO WORLD", i, 1, 2)
@DannyWeitekamp
Copy link
Author

On my machine (total time for 1000 calls each, in milliseconds):

njit: 1.40 ms
python: 0.07 ms
cfunc: 0.20 ms
entry_point: 0.11 ms

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment