# Mandelbrot in Mojo with Python plots

Not only is Mojo great for writing high-performance code, but it also allows us to leverage the huge Python ecosystem of libraries and tools. With seamless Python interoperability, Mojo can use Python for what it's good at, especially GUIs, without sacrificing performance in critical code. Let's take the classic Mandelbrot set algorithm and implement it in Mojo. We'll introduce a `Complex` type and use it in our implementation.

## Mandelbrot in Python

```mojo
%%python
import numpy as np
import numba
import matplotlib.pyplot as plt
import matplotlib.colors as colors
import time
```

```mojo
%%python
# Constants
xmin = -2.25
xmax = 0.75
xn = 450
ymin = -1.25
ymax = 1.25
yn = 375
max_iter = 200

# Compute the number of steps to escape
def mandelbrot_kernel(c):
    z = c
    for i in range(max_iter):
        z = z * z + c
        if abs(z) > 2:
            return i
    return max_iter

def mandelbrot():
    # Create a matrix. Each element of the matrix corresponds to a pixel
    result = np.zeros((yn, xn), dtype=np.uint32)

    dx = (xmax - xmin) / xn
    dy = (ymax - ymin) / yn

    y = ymin
    for j in range(yn):
        x = xmin
        for i in range(xn):
            result[j, i] = mandelbrot_kernel(complex(x, y))
            x += dx
        y += dy
    return result

def make_plot_python(m):
    dpi = 32
    width = 5
    height = 5 * yn // xn

    fig = plt.figure(1, [width, height], dpi=dpi)
    ax = fig.add_axes([0.0, 0.0, 1.0, 1.0], frame_on=False, aspect=1)

    light = colors.LightSource(315, 10, 0, 1, 1, 0)

    image = light.shade(m, plt.cm.hot, colors.PowerNorm(0.3), blend_mode='hsv', vert_exag=1.5)
    plt.imshow(image)
    plt.axis("off")
    plt.show()
```

```mojo
%%python
start_time = time.time()
mandelbrot_set = mandelbrot()
end_time = time.time()

execution_time = (end_time - start_time) * 1000  # Make it milliseconds
make_plot_python(mandelbrot_set)
print(f"Execution time for Python Mandelbrot: {execution_time:.0f} ms")
```

![output_4_0](https://user-images.githubusercontent.com/6831355/237290290-7f640355-555c-4836-adbf-c280b74d06e6.png)

    Execution time for Python Mandelbrot: 1266 ms

## Python Numba JIT compiler

```mojo
%%python
# Run with Numba JIT compiler
@numba.jit(nopython=True)
def mandelbrot_kernel_numba(c):
    z = c
    for i in range(max_iter):
        z = z * z + c
        if abs(z) > 2:
            return i
    return max_iter

@numba.jit(nopython=True)
def mandelbrot_numba():
    # Create a matrix. Each element of the matrix corresponds to a pixel
    result = np.zeros((yn, xn), dtype=np.uint32)

    dx = (xmax - xmin) / xn
    dy = (ymax - ymin) / yn

    y = ymin
    for j in range(yn):
        x = xmin
        for i in range(xn):
            result[j, i] = mandelbrot_kernel_numba(complex(x, y))
            x += dx
        y += dy
    return result
```

```mojo
%%python
dummy = mandelbrot_numba()  # Compile numba first
start_time = time.time()
mandelbrot_set = mandelbrot_numba()
end_time = time.time()

execution_time = (end_time - start_time) * 1000  # Make it milliseconds
make_plot_python(mandelbrot_set)
print(f"Execution time for Python Mandelbrot (numba): {execution_time:.0f} ms")
```

![output_7_0](https://user-images.githubusercontent.com/6831355/237290295-1a1b4a54-b3ee-4a2f-ab2d-57322e2f19b1.png)

    Execution time for Python Mandelbrot (numba): 60 ms

## Python vectorized

```mojo
%%python
def mandelbrot_vectorized(xn, yn, max_iter=200):
    # Define the boundaries of the complex plane
    xmin = -2.25
    xmax = 0.75
    ymin = -1.25
    ymax = 1.25

    # Create the grid of complex numbers
    x = np.linspace(xmin, xmax, xn)
    y = np.linspace(ymin, ymax, yn)
    c = np.array([[complex(re, im) for re in x] for im in y])

    # Initialize the Mandelbrot set and iteration count array
    mandelbrot_set = np.zeros((yn, xn), dtype=np.uint32)
    iter_count = np.zeros_like(mandelbrot_set)

    # Initialize the z values with the complex grid
    z = c.copy()

    # Iterate over each point using vectorized operations
    for i in range(max_iter):
        # Update z values based on the Mandelbrot equation
        z = z**2 + c

        # Update the iteration count for points that have not escaped
        iter_count[(np.abs(z) < 2) & (mandelbrot_set == 0)] = i

        # Mark points that have escaped
        mandelbrot_set[np.abs(z) >= 2] = 1

    # Replace points that never escaped with the maximum iteration count
    iter_count[mandelbrot_set == 0] = max_iter

    return iter_count
```

```mojo
%%python
start_time = time.time()
mandelbrot_set = mandelbrot_vectorized(xn, yn)
end_time = time.time()

execution_time = (end_time - start_time) * 1000
make_plot_python(mandelbrot_set)
print(f"Execution time for Python Mandelbrot (vectorized): {execution_time:.0f} ms")
```

    :47: RuntimeWarning: overflow encountered in square
    :47: RuntimeWarning: invalid value encountered in square

![output_10_1](https://user-images.githubusercontent.com/6831355/237290298-0d056af6-e486-4e4d-abf9-ed360dd23609.png)

    Execution time for Python Mandelbrot (vectorized): 239 ms

## Python vectorized Numba JIT

```mojo
%%python
@numba.vectorize([numba.uint32(numba.complex128, numba.uint32)], nopython=True)
def mandelbrot_element(c, max_iter):
    z = c
    for i in range(max_iter):
        z = z**2 + c
        if abs(z) > 2:
            return i
    return max_iter

def mandelbrot_vectorized_numba(xn, yn, max_iter=200):
    # Define the boundaries of the complex plane
    xmin = -2.25
    xmax = 0.75
    ymin = -1.25
    ymax = 1.25

    # Create the grid of complex numbers
    x = np.linspace(xmin, xmax, xn)
    y = np.linspace(ymin, ymax, yn)
    c = np.array([[complex(re, im) for re in x] for im in y])

    # Compute the Mandelbrot set element-wise using the vectorized function
    iter_count = mandelbrot_element(c, max_iter)

    return iter_count
```

```mojo
%%python
dummy = mandelbrot_vectorized_numba(xn, yn)  # Compile numba first
start_time = time.time()
mandelbrot_set = mandelbrot_vectorized_numba(xn, yn)
end_time = time.time()

execution_time = (end_time - start_time) * 1000
make_plot_python(mandelbrot_set)
print(f"Execution time for Python Mandelbrot (vectorized-numba): {execution_time:.0f} ms")
```

![output_13_0](https://user-images.githubusercontent.com/6831355/237290301-9276527a-7fa3-43f4-a142-d7049329292f.png)

    Execution time for Python Mandelbrot (vectorized-numba): 93 ms

## Cython (can't load extension)

```mojo
# %%python
# import os
# os.system('pip install cython')
```

```mojo
# %load_ext cython
```

```mojo
# %%cython
# ## Try with cython
# import numpy as np
# cimport numpy as np

# # Constants
# cdef double xmin = -2.25
# cdef double xmax = 0.75
# cdef int xn = 450
# cdef double ymin = -1.25
# cdef double ymax = 1.25
# cdef int yn = 375
# cdef int max_iter = 200

# # Mandelbrot computation in Cython
# cpdef np.ndarray[np.uint32_t, ndim=2] mandelbrot_cython():
#     cdef double dx = (xmax - xmin) / xn
#     cdef double dy = (ymax - ymin) / yn
#     cdef np.ndarray[np.uint32_t, ndim=2] result = np.zeros((yn, xn), dtype=np.uint32)
#     cdef double x, y, real, imag, abs_val
#     cdef int i, j, k

#     y = ymin
#     for j in range(yn):
#         x = xmin
#         for i in range(xn):
#             real = x
#             imag = y
#             for k in range(max_iter):
#                 abs_val = real * real + imag * imag
#                 if abs_val > 4:
#                     break
#                 real, imag = real * real - imag * imag + x, 2 * real * imag + y
#             result[j, i] = k
#             x += dx
#         y += dy
#     return result
```

```mojo
# start_time = time.time()
# mandelbrot_set = mandelbrot_cython()
# end_time = time.time()

# execution_time = (end_time - start_time) * 1000  # Make it milliseconds
# make_plot_python(mandelbrot_set)
# print(f"Execution time for Python Mandelbrot (cython): {execution_time:.0f} ms")
```

## Mandelbrot in Mojo

```mojo
from Benchmark import Benchmark
from DType import DType
from Memory import memset_zero
from Object import object, Attr
from Pointer import DTypePointer, Pointer
from Random import rand
from Range import range
from TargetInfo import dtype_sizeof
from Time import now
from Complex import ComplexSIMD as ComplexGenericSIMD
```

```mojo
struct Matrix:
    var data: DTypePointer[DType.si64]
    var rows: Int
    var cols: Int
    # Shared reference count: copies share `data`, which is freed by the last owner
    var rc: Pointer[Int]

    fn __init__(self&, cols: Int, rows: Int):
        self.data = DTypePointer[DType.si64].alloc(rows * cols)
        self.rows = rows
        self.cols = cols
        self.rc = Pointer[Int].alloc(1)
        self.rc.store(1)

    fn __copyinit__(self&, other: Self):
        other._inc_rc()
        self.data = other.data
        self.rc = other.rc
        self.rows = other.rows
        self.cols = other.cols

    fn __del__(owned self):
        self._dec_rc()

    fn _get_rc(self) -> Int:
        return self.rc.load()

    fn _dec_rc(self):
        let rc = self._get_rc()
        if rc > 1:
            self.rc.store(rc - 1)
            return
        self._free()

    fn _inc_rc(self):
        let rc = self._get_rc()
        self.rc.store(rc + 1)

    fn _free(self):
        self.data.free()
        self.rc.free()

    @always_inline
    fn __getitem__(self, col: Int, row: Int) -> SI64:
        return self.load[1](col, row)

    # Row-major layout: element (col, row) lives at row * cols + col
    @always_inline
    fn load[nelts:Int](self, col: Int, row: Int) -> SIMD[DType.si64, nelts]:
        return self.data.simd_load[nelts](row * self.cols + col)

    @always_inline
    fn __setitem__(self, col: Int, row: Int, val: SI64):
        return self.store[1](col, row, val)

    @always_inline
    fn store[nelts:Int](self, col: Int, row: Int, val: SIMD[DType.si64, nelts]):
        self.data.simd_store[nelts](row * self.cols + col, val)

    def to_numpy(self) -> PythonObject:
        let np = Python.import_module("numpy")
        let numpy_array = np.zeros((self.rows, self.cols), np.uint32)
        for col in range(self.cols):
            for row in range(self.rows):
                numpy_array.itemset((row, col), self[col, row].cast[DType.f32]())
        return numpy_array
```

```mojo
@register_passable("trivial")
struct Complex:
    var real: F32
    var imag: F32

    fn __init__(real: F32, imag: F32) -> Self:
        return Self {real: real, imag: imag}

    fn __add__(lhs, rhs: Self) -> Self:
        return Self(lhs.real + rhs.real, lhs.imag + rhs.imag)

    fn __mul__(lhs, rhs: Self) -> Self:
        return Self(
            lhs.real * rhs.real - lhs.imag * rhs.imag,
            lhs.real * rhs.imag + lhs.imag * rhs.real,
        )

    # Squared magnitude |z|^2 (no square root needed)
    fn norm(self) -> F32:
        return self.real * self.real + self.imag * self.imag
```

Then we can write the core Mandelbrot algorithm, which involves computing an iterative complex function for each pixel until it "escapes" the complex circle of radius 2, counting the number of iterations to escape.

$$z_{i+1} = z_i^2 + c$$

Since `norm()` returns the squared magnitude $|z|^2$, the test `z.norm() > 4` is equivalent to $|z| > 2$ and avoids computing a square root.

```mojo
alias xmin: F32 = -2.25
alias xmax: F32 = 0.75
alias xn = 450
alias ymin: F32 = -1.25
alias ymax: F32 = 1.25
alias yn = 375

# Compute the number of steps to escape.
def mandelbrot_kernel(c: Complex) -> Int:
    max_iter = 200
    z = c
    for i in range(max_iter):
        z = z * z + c
        if z.norm() > 4:
            return i
    return max_iter

def compute_mandelbrot() -> Matrix:
    # create a matrix. Each element of the matrix corresponds to a pixel
    result = Matrix(xn, yn)

    dx = (xmax - xmin) / xn
    dy = (ymax - ymin) / yn

    y = ymin
    for j in range(yn):
        x = xmin
        for i in range(xn):
            result[i, j] = mandelbrot_kernel(Complex(x, y))
            x += dx
        y += dy
    return result
```

Plotting the number of iterations to escape with some color gives us the canonical Mandelbrot set plot. To render it we can directly leverage Python's `matplotlib` right from Mojo!

```mojo
def make_plot(m: Matrix):
    np = Python.import_module("numpy")
    plt = Python.import_module("matplotlib.pyplot")
    colors = Python.import_module("matplotlib.colors")
    dpi = 32
    width = 5
    height = 5 * yn // xn

    fig = plt.figure(1, [width, height], dpi)
    ax = fig.add_axes([0.0, 0.0, 1.0, 1.0], False, 1)

    light = colors.LightSource(315, 10, 0, 1, 1, 0)

    image = light.shade(m.to_numpy(), plt.cm.hot, colors.PowerNorm(0.3), "hsv", 0, 0, 1.5)
    plt.imshow(image)
    plt.axis("off")
    plt.show()
```

```mojo
let eval_begin: Int = now()  # This is in nanoseconds
let mandelbrot_set = compute_mandelbrot()
let eval_end: Int = now()
let execution_time = (eval_end - eval_begin) // 1000000  # Nanoseconds to milliseconds
make_plot(mandelbrot_set)
print('Execution time for Mojo Mandelbrot: ', execution_time, 'ms')
```

![output_27_0](https://user-images.githubusercontent.com/6831355/237290304-35799740-73b9-4ac4-959d-da0fa8391a39.png)

    Execution time for Mojo Mandelbrot:  27 ms

## Vectorizing Mandelbrot

We showed a naive implementation of the Mandelbrot algorithm, but there are two things we can do to speed it up. We can stop iterating early once all the pixels being processed together are known to have escaped, and we can leverage Mojo's access to the hardware by vectorizing the loop so that multiple pixels are computed simultaneously. To do that we will use the `vectorize` higher-order generator.
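The lane-masking idea described above can be illustrated outside Mojo with a small NumPy sketch. This is a Python illustration under stated assumptions, not the Mojo implementation: a 1-D array stands in for one SIMD vector, a boolean array plays the role of an escape mask, counts only grow for lanes that are still inside, and the loop exits early once every lane has escaped. The name `mandelbrot_lanes` is hypothetical. One deliberate difference from hardware SIMD: NumPy lets us skip escaped lanes with boolean indexing, which also avoids the overflow warnings seen in the vectorized Python version.

```python
import numpy as np

def mandelbrot_lanes(c, max_iter=200):
    """Hypothetical helper: escape counts for a 'vector' of points processed together."""
    z = c.copy()
    counts = np.zeros(c.shape, dtype=np.int64)   # per-lane iteration counts
    escaped = np.zeros(c.shape, dtype=bool)      # per-lane escape mask
    for _ in range(max_iter):
        if escaped.all():
            break  # early stop: every lane has escaped
        active = ~escaped
        # Update only lanes that are still inside (a real SIMD unit would update all lanes)
        z[active] = z[active] * z[active] + c[active]
        # |z|^2 > 4 is the same test as |z| > 2, without a square root
        escaped |= (z.real ** 2 + z.imag ** 2) > 4.0
        # Count one more iteration only for lanes that have not escaped
        counts[~escaped] += 1
    return counts
```

For example, `mandelbrot_lanes(np.array([0j, 1+1j, 0.5+0j]))` gives the same counts as the scalar kernel: 200 for the origin (never escapes), 0 for `1+1j`, and 3 for `0.5`.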
We start by defining our main iteration loop in a vectorized fashion:

```mojo
fn mandelbrot_kernel_simd[simd_width:Int](c: ComplexGenericSIMD[DType.f32, simd_width]) -> SIMD[DType.si64, simd_width]:
    var z = c
    var nv = SIMD[DType.si64, simd_width](0)
    var escape_mask = SIMD[DType.bool, simd_width](0)

    var i = 200
    while i != 0 and not escape_mask:
        z = z*z + c
        # Only update elements that haven't escaped yet
        escape_mask = escape_mask.select(escape_mask, z.norm() > 4)
        nv = escape_mask.select(nv, nv + 1)
        i -= 1

    return nv
```

The above function is parameterized on `simd_width` and processes `simd_width` pixels at once. It only exits the loop once all pixels in the vector have escaped, or after 200 iterations.

We can use the same iteration loop as above, but this time we vectorize within each row instead. We use the `vectorize` generator to make this a simple function call.

```mojo
from Functional import vectorize
from Math import iota
from TargetInfo import dtype_simd_width

def compute_mandelbrot_simd() -> Matrix:
    # create a matrix. Each element of the matrix corresponds to a pixel
    var result = Matrix(xn, yn)

    let dx = (xmax - xmin) / xn
    let dy = (ymax - ymin) / yn

    var y = ymin
    alias simd_width = dtype_simd_width[DType.f32]()

    for row in range(yn):
        var x = xmin

        @parameter
        fn _process_simd_element[simd_width:Int](col: Int):
            # Real parts x, x + dx, ..., x + (simd_width - 1) * dx; same imaginary part y
            let c = ComplexGenericSIMD[DType.f32, simd_width](dx*iota[simd_width, DType.f32]() + x,
                                                               SIMD[DType.f32, simd_width](y))
            result.store[simd_width](col, row, mandelbrot_kernel_simd[simd_width](c))
            x += simd_width*dx

        vectorize[simd_width, _process_simd_element](xn)
        y += dy

    return result
```

```mojo
let eval_begin: Int = now()
let mandelbrot_set = compute_mandelbrot_simd()
let eval_end: Int = now()
let execution_time = (eval_end - eval_begin) // 1000000
make_plot(mandelbrot_set)
print('Execution time for Mojo Mandelbrot (vectorized): ', execution_time, 'ms')
```

![output_34_0](https://user-images.githubusercontent.com/6831355/237290305-06f18b52-3b1b-4c11-97bc-96330b852eea.png)

    Execution time for Mojo Mandelbrot (vectorized):  2 ms

## Parallelizing Mandelbrot

While the vectorized implementation above is efficient, we can try to get better performance by parallelizing over the rows. This again is simple in Mojo using the `parallelize` higher-order function. Only the function that performs the invocation needs to change.

```mojo
from Functional import parallelize

def compute_mandelbrot_simd_parallel() -> Matrix:
    # create a matrix. Each element of the matrix corresponds to a pixel
    var result = Matrix(xn, yn)

    let dx = (xmax - xmin) / xn
    let dy = (ymax - ymin) / yn

    alias simd_width = dtype_simd_width[DType.f32]()

    # Each row is an independent task, so rows can be computed in parallel
    @parameter
    fn _process_row(row:Int):
        var y = ymin + dy*row
        var x = xmin

        @parameter
        fn _process_simd_element[simd_width:Int](col: Int):
            let c = ComplexGenericSIMD[DType.f32, simd_width](dx*iota[simd_width, DType.f32]() + x,
                                                               SIMD[DType.f32, simd_width](y))
            result.store[simd_width](col, row, mandelbrot_kernel_simd[simd_width](c))
            x += simd_width*dx

        vectorize[simd_width, _process_simd_element](xn)

    parallelize[_process_row](yn)
    return result
```

```mojo
let eval_begin: Int = now()
let mandelbrot_set = compute_mandelbrot_simd_parallel()
let eval_end: Int = now()
let execution_time = (eval_end - eval_begin) // 1000000
make_plot(mandelbrot_set)
print('Execution time for Mojo Mandelbrot (vectorized-parallelized): ', execution_time, 'ms')
```

![output_38_0](https://user-images.githubusercontent.com/6831355/237290310-6cc95333-46fb-4ada-a328-fab00f8b36ff.png)

    Execution time for Mojo Mandelbrot (vectorized-parallelized):  4 ms

At this image size the measured parallel time (4 ms) is not lower than the single-threaded vectorized time (2 ms). With only 375 short rows, the cost of dispatching work to threads likely outweighs the gain, and single-run timings truncated to whole milliseconds are coarse.
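The row-level decomposition that `parallelize` performs can be sketched in plain Python with `concurrent.futures.ThreadPoolExecutor`. This is an illustrative sketch, not Mojo's scheduler: the names `compute_row` and `mandelbrot_rows_parallel` are hypothetical, each task fills exactly one row (so no two threads write the same memory), and each row is vectorized along its columns with NumPy, mirroring the `vectorize`-inside-`parallelize` structure above. Because of CPython's global interpreter lock, it shows how the work is split rather than promising a speedup.

```python
from concurrent.futures import ThreadPoolExecutor
import numpy as np

# Same plane bounds, resolution, and iteration limit as the notebook
XMIN, XMAX, XN = -2.25, 0.75, 450
YMIN, YMAX, YN = -1.25, 1.25, 375
MAX_ITER = 200

def compute_row(row, result):
    """Hypothetical helper: fill result[row] with escape counts, vectorized along the row."""
    dx = (XMAX - XMIN) / XN
    dy = (YMAX - YMIN) / YN
    y = YMIN + dy * row
    c = (XMIN + dx * np.arange(XN)) + 1j * y
    z = c.copy()
    counts = np.zeros(XN, dtype=np.int64)
    active = np.ones(XN, dtype=bool)  # pixels that have not escaped yet
    for _ in range(MAX_ITER):
        if not active.any():
            break  # every pixel in this row has escaped
        z[active] = z[active] * z[active] + c[active]
        escaped_now = active & (z.real ** 2 + z.imag ** 2 > 4.0)
        active &= ~escaped_now
        counts[active] += 1
    result[row] = counts

def mandelbrot_rows_parallel(workers=4):
    """Hypothetical helper: one task per row, dispatched to a thread pool."""
    result = np.zeros((YN, XN), dtype=np.int64)
    with ThreadPoolExecutor(max_workers=workers) as pool:
        # list() forces all tasks to finish and re-raises any exception
        list(pool.map(lambda r: compute_row(r, result), range(YN)))
    return result
```

Since every row is independent, the parallel result is identical to calling `compute_row` for each row in order; only the scheduling differs.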