import elegy
import jax
import jax.numpy as jnp
import numpy as np
import optax