from functools import partial import jax import jax.numpy as jnp import optax def poly(x: jnp.ndarray, w: jnp.ndarray): assert w.shape == (3,) w = w.astype(jnp.float32) return w[0] * x + w[1] * x**3 + w[2] * x**5 def poly_chain(x: jnp.ndarray, w_seq: jnp.ndarray): y = [x] for w in w_seq: y.append(poly(y[-1], w)) return y def min_of_polys(x: jnp.ndarray, w_seq: jnp.ndarray): y = jnp.full_like(x, jnp.inf) for w in w_seq: y = jnp.minimum(y, poly(x, w)) return y @partial(jax.jit, static_argnums=(2, 3)) def optimize_w( w_seq: jnp.ndarray, lr: float, n: int, debug: bool = False, ): def loss(w_seq: jnp.ndarray): w_seq = w_seq.astype(jnp.bfloat16) xs = (jnp.arange(2048) + 1) / 2048 *zs, ys = jax.vmap(poly_chain, in_axes=(0, None))(xs, w_seq) y_max = jnp.amax(ys) y_min = jnp.amin(jnp.where(xs > 1 / 128, ys, jnp.inf)) diff_ratio = (y_max - y_min) / jnp.clip(y_max, min=1e-3) slope_xs = (jnp.arange(320) + 1) / 256 min_ps = jax.vmap(min_of_polys, in_axes=(0, None))(slope_xs, w_seq) min_slope = jnp.amin(min_ps / slope_xs) z_max_seq = [jnp.amax(z) for z in zs] max_next_excess = sum( jnp.clip(poly(z + 1 / 16, w) - z, min=0) for z, w in zip(z_max_seq, w_seq) ) obj_0 = ys[0] / y_max # larger is better obj_1 = y_max # closer to 1 is better obj_2 = jnp.log2(diff_ratio) # smaller is better obj_3 = min_slope # larger is better obj_4 = max_next_excess # smaller is better objectives = (obj_0, obj_1, obj_2, obj_3, obj_4) if debug: jax.debug.print("{x}", x=objectives) loss = -4.0 * obj_0 loss += 16.0 * jnp.square(obj_1 - 1) loss += 2.0 * jnp.clip(obj_2, min=-10) loss += -4.0 * jnp.clip(obj_3, max=1 / 2) loss += 64.0 * obj_4 return loss, objectives loss_and_grad_fn = jax.value_and_grad(loss, argnums=0, has_aux=True) optimizer = optax.chain( optax.adam(learning_rate=lr), optax.clip_by_global_norm(1.0), ) opt_state = optimizer.init(w_seq) def body_fn(carry: tuple[jnp.ndarray, optax.OptState], _): w_seq, opt_state = carry (_, objectives), grad = loss_and_grad_fn(w_seq) updates, opt_state = optimizer.update(grad, opt_state) w_seq = optax.apply_updates(w_seq, updates) return (w_seq, opt_state), objectives (w_seq, _), objectives = jax.lax.scan(body_fn, (w_seq, opt_state), length=n) return w_seq, objectives def main(): BASE = 128 w_seq = jnp.array([[3.5, -6.04444444444, 2.84444444444]] * 6) for i in range(5): w_seq, objectives = optimize_w(w_seq, lr=2e-3, n=100000) print(w_seq.astype(jnp.bfloat16) * BASE) print(i, [obj[-1].item() for obj in objectives]) for i in range(5): w_seq, objectives = optimize_w(w_seq, lr=1e-3, n=100000) print(w_seq.astype(jnp.bfloat16) * BASE) print(i, [obj[-1].item() for obj in objectives]) for i in range(5): w_seq, objectives = optimize_w(w_seq, lr=5e-4, n=100000) print(w_seq.astype(jnp.bfloat16) * BASE) print(i, [obj[-1].item() for obj in objectives]) for i in range(20): w_seq, objectives = optimize_w(w_seq, lr=1e-4, n=100000) print(w_seq.astype(jnp.bfloat16) * BASE) print(i, [obj[-1].item() for obj in objectives]) if __name__ == "__main__": main()