Last active
August 26, 2025 08:24
-
-
Save YouJiacheng/393c90cbdc23b09d5688815ba382288b to your computer and use it in GitHub Desktop.
Revisions
-
YouJiacheng revised this gist
Feb 24, 2025 . 1 changed file with 13 additions and 16 deletions. There are no files selected for viewing.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters. Learn more about bidirectional Unicode charactersOriginal file line number Diff line number Diff line change @@ -1,6 +1,3 @@ from functools import partial import jax @@ -10,6 +7,7 @@ def poly(x: jnp.ndarray, w: jnp.ndarray): assert w.shape == (3,) w = w.astype(jnp.float32) return w[0] * x + w[1] * x**3 + w[2] * x**5 @@ -35,17 +33,15 @@ def optimize_w( debug: bool = False, ): def loss(w_seq: jnp.ndarray): w_seq = w_seq.astype(jnp.bfloat16) xs = (jnp.arange(2048) + 1) / 2048 *zs, ys = jax.vmap(poly_chain, in_axes=(0, None))(xs, w_seq) y_max = jnp.amax(ys) y_min = jnp.amin(jnp.where(xs > 1 / 128, ys, jnp.inf)) diff_ratio = (y_max - y_min) / jnp.clip(y_max, min=1e-3) slope_xs = (jnp.arange(320) + 1) / 256 min_ps = jax.vmap(min_of_polys, in_axes=(0, None))(slope_xs, w_seq) min_slope = jnp.amin(min_ps / slope_xs) z_max_seq = [jnp.amax(z) for z in zs] @@ -65,8 +61,8 @@ def loss(w_seq: jnp.ndarray): jax.debug.print("{x}", x=objectives) loss = -4.0 * obj_0 loss += 16.0 * jnp.square(obj_1 - 1) loss += 2.0 * jnp.clip(obj_2, min=-10) loss += -4.0 * jnp.clip(obj_3, max=1 / 2) loss += 64.0 * obj_4 return loss, objectives @@ -81,31 +77,32 @@ def loss(w_seq: jnp.ndarray): def body_fn(carry: tuple[jnp.ndarray, optax.OptState], _): w_seq, opt_state = carry (_, objectives), grad = loss_and_grad_fn(w_seq) updates, opt_state = optimizer.update(grad, opt_state) w_seq = optax.apply_updates(w_seq, updates) return (w_seq, opt_state), objectives (w_seq, _), objectives = jax.lax.scan(body_fn, (w_seq, opt_state), length=n) return w_seq, objectives def main(): BASE = 128 w_seq = jnp.array([[3.5, -6.04444444444, 2.84444444444]] * 6) for i in range(5): w_seq, objectives = optimize_w(w_seq, lr=2e-3, n=100000) print(w_seq.astype(jnp.bfloat16) * BASE) print(i, [obj[-1].item() for obj in 
objectives]) for i in range(5): w_seq, objectives = optimize_w(w_seq, lr=1e-3, n=100000) print(w_seq.astype(jnp.bfloat16) * BASE) print(i, [obj[-1].item() for obj in objectives]) for i in range(5): w_seq, objectives = optimize_w(w_seq, lr=5e-4, n=100000) print(w_seq.astype(jnp.bfloat16) * BASE) print(i, [obj[-1].item() for obj in objectives]) for i in range(20): w_seq, objectives = optimize_w(w_seq, lr=1e-4, n=100000) print(w_seq.astype(jnp.bfloat16) * BASE) print(i, [obj[-1].item() for obj in objectives]) -
YouJiacheng revised this gist
Feb 23, 2025 . 1 changed file with 60 additions and 48 deletions. There are no files selected for viewing.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters. Learn more about bidirectional Unicode charactersOriginal file line number Diff line number Diff line change @@ -1,3 +1,6 @@ from typing import Any from functools import partial import jax @@ -11,90 +14,99 @@ def poly(x: jnp.ndarray, w: jnp.ndarray): def poly_chain(x: jnp.ndarray, w_seq: jnp.ndarray): y = [x] for w in w_seq: y.append(poly(y[-1], w)) return y def min_of_polys(x: jnp.ndarray, w_seq: jnp.ndarray): y = jnp.full_like(x, jnp.inf) for w in w_seq: y = jnp.minimum(y, poly(x, w)) return y @partial(jax.jit, static_argnums=(2, 3)) def optimize_w( w_seq: jnp.ndarray, lr: float, n: int, debug: bool = False, ): def loss(w_seq: jnp.ndarray): w_seq = w_seq + jax.lax.stop_gradient( jnp.round(1024 * w_seq, 0) / 1024 - w_seq ) # STE xs = (jnp.arange(2048) + 1) / 2048 *zs, ys = jax.vmap(poly_chain, in_axes=(0, None))(xs, w_seq) y_max = jnp.amax(ys) y_min = jnp.amin(jnp.where(xs > 1 / 128, ys, jnp.inf)) diff_ratio = (y_max - y_min) / jnp.clip(y_max, min=1e-3) slope_xs = (jnp.arange(320) + 1) / 256 min_ps: Any = jax.vmap(min_of_polys, in_axes=(0, None))(slope_xs, w_seq) min_slope = jnp.amin(min_ps / slope_xs) z_max_seq = [jnp.amax(z) for z in zs] max_next_excess = sum( jnp.clip(poly(z + 1 / 16, w) - z, min=0) for z, w in zip(z_max_seq, w_seq) ) obj_0 = ys[0] / y_max # larger is better obj_1 = y_max # closer to 1 is better obj_2 = jnp.log2(diff_ratio) # smaller is better obj_3 = min_slope # larger is better obj_4 = max_next_excess # smaller is better objectives = (obj_0, obj_1, obj_2, obj_3, obj_4) if debug: jax.debug.print("{x}", x=objectives) loss = -4.0 * obj_0 loss += 16.0 * jnp.abs(obj_1 - 1) loss += 2.0 * jnp.clip(obj_2, min=-8) loss += -4.0 * jnp.clip(obj_3, max=1 / 2) loss += 64.0 * obj_4 return loss, objectives loss_and_grad_fn = jax.value_and_grad(loss, 
argnums=0, has_aux=True) optimizer = optax.chain( optax.adam(learning_rate=lr), optax.clip_by_global_norm(1.0), ) opt_state = optimizer.init(w_seq) def body_fn(carry: tuple[jnp.ndarray, optax.OptState], _): w_seq, opt_state = carry (_, objectives), grad = loss_and_grad_fn(w_seq) updates, opt_state = optimizer.update(grad, opt_state, params=w_seq) w_seq = optax.apply_updates(w_seq, updates) return (w_seq, opt_state), objectives (w_seq, _), objectives = jax.lax.scan(body_fn, (w_seq, opt_state), length=n) return jnp.round(1024 * w_seq, 0) / 1024, objectives def main(): w_seq = jnp.array([[3.5, -6.04444444444, 2.84444444444]] * 6) for i in range(5): w_seq, objectives = optimize_w(w_seq, lr=2e-3, n=100000) print(w_seq * 1024) print(i, [obj[-1].item() for obj in objectives]) for i in range(5): w_seq, objectives = optimize_w(w_seq, lr=1e-3, n=100000) print(w_seq * 1024) print(i, [obj[-1].item() for obj in objectives]) for i in range(5): w_seq, objectives = optimize_w(w_seq, lr=5e-4, n=100000) print(w_seq * 1024) print(i, [obj[-1].item() for obj in objectives]) for i in range(5): w_seq, objectives = optimize_w(w_seq, lr=1e-4, n=100000) print(w_seq * 1024) print(i, [obj[-1].item() for obj in objectives]) if __name__ == "__main__": -
YouJiacheng revised this gist
Feb 23, 2025 . 1 changed file with 15 additions and 15 deletions. There are no files selected for viewing.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters. Learn more about bidirectional Unicode charactersOriginal file line number Diff line number Diff line change @@ -33,7 +33,7 @@ def optimize_bc( debug: bool = False, ): def loss(bc_seq: jnp.ndarray): bc_seq = bc_seq + jax.lax.stop_gradient(jnp.round(1024 * bc_seq, 0) / 1024 - bc_seq) # STE xs = jnp.concat( [ jnp.linspace(0, 1 / 128, 256), @@ -55,14 +55,14 @@ def loss(bc_seq: jnp.ndarray): ) obj_1 = y_max # smaller is better obj_2 = jnp.log2(diff_ratio) # smaller is better obj_3 = min_slope # larger is better if debug: jax.debug.print("{x} {y} {z}", x=obj_1, y=obj_2, z=obj_3) return jnp.clip(obj_1, min=1) + 4.0 * jnp.clip( obj_2, min=jnp.log2(1 / 1024) ) + -8.0 * jnp.clip(obj_3, max=0.5), (obj_1, obj_2, obj_3) loss_and_grad_fn = jax.value_and_grad(loss, argnums=0, has_aux=True) optimizer = optax.chain( @@ -79,22 +79,22 @@ def body_fn(carry: tuple[jnp.ndarray, optax.OptState], _): return (bc_seq, opt_state), (obj_1, obj_2, obj_3) (bc_seq, _), losses = jax.lax.scan(body_fn, (bc_seq, opt_state), length=n) return jnp.round(1024 * bc_seq, 0) / 1024, losses def main(): a_seq = jnp.array([3.5] * 6) bc_seq = jnp.array([[-6.04444444444, 2.84444444444]] * 6) for i in range(20): bc_seq, objectives = optimize_bc(a_seq, bc_seq, lr=5e-4, n=100000) print(bc_seq * 1024) print(i, objectives[0][-1], objectives[1][-1], objectives[2][-1]) for i in range(20): bc_seq, objectives = optimize_bc(a_seq, bc_seq, lr=1e-4, n=100000) print(bc_seq * 1024) print(i, objectives[0][-1], objectives[1][-1], objectives[2][-1]) bc_seq, objectives = optimize_bc(a_seq, bc_seq, lr=2e-5, n=10000, debug=True) print(bc_seq * 1024) if __name__ == "__main__": -
YouJiacheng revised this gist
Feb 17, 2025 . 1 changed file with 37 additions and 25 deletions. There are no files selected for viewing.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters. Learn more about bidirectional Unicode charactersOriginal file line number Diff line number Diff line change @@ -17,6 +17,13 @@ def poly_chain(x: jnp.ndarray, w_seq: jnp.ndarray): return y def poly_min(x: jnp.ndarray, w_seq: jnp.ndarray): y = jnp.full_like(x, jnp.inf) for w in w_seq: y = jnp.minimum(y, poly(x, w)) return y @partial(jax.jit, static_argnums=(3, 4)) def optimize_bc( a_seq: jnp.ndarray, @@ -26,27 +33,36 @@ def optimize_bc( debug: bool = False, ): def loss(bc_seq: jnp.ndarray): bc_seq = bc_seq + jax.lax.stop_gradient(jnp.round(bc_seq, 2) - bc_seq) # STE xs = jnp.concat( [ jnp.linspace(0, 1 / 128, 256), jnp.linspace(0, 1 / 16, 512), jnp.linspace(0, 1, 1024), jnp.linspace(1, 1 + 1 / 8, 256), ] ) assert a_seq.shape == (bc_seq.shape[0],) w_seq = jnp.concat([a_seq[:, None], bc_seq], axis=1) ys = jax.vmap(poly_chain, in_axes=(0, None))(xs, w_seq) y_max = jnp.amax(ys) y_min = jnp.amin(jnp.where(xs > 1 / 128, ys, jnp.inf)) diff_ratio = (y_max - y_min) / jnp.clip(y_max, min=1e-3) min_zs = jax.vmap(poly_min, in_axes=(0, None))(xs, w_seq) min_slope = jnp.amin( jnp.where(xs > 1 / 128, min_zs / jnp.clip(xs, min=1 / 128), jnp.inf) ) obj_1 = y_max # smaller is better obj_2 = diff_ratio # smaller is better obj_3 = min_slope # larger is better if debug: jax.debug.print("{x} {y} {z}", x=obj_1, y=obj_2, z=obj_3) return jnp.clip(obj_1, min=1) + 8.0 * jnp.clip( obj_2, min=1 / 32 ) + -4.0 * jnp.clip(obj_3, max=0.5), (obj_1, obj_2, obj_3) loss_and_grad_fn = jax.value_and_grad(loss, argnums=0, has_aux=True) optimizer = optax.chain( @@ -57,32 +73,28 @@ def loss(bc_seq: jnp.ndarray): def body_fn(carry: tuple[jnp.ndarray, optax.OptState], _): bc_seq, opt_state = carry (_, (obj_1, obj_2, obj_3)), grad = loss_and_grad_fn(bc_seq) updates, opt_state = optimizer.update(grad, 
opt_state, params=bc_seq) bc_seq = optax.apply_updates(bc_seq, updates) return (bc_seq, opt_state), (obj_1, obj_2, obj_3) (bc_seq, _), losses = jax.lax.scan(body_fn, (bc_seq, opt_state), length=n) return jnp.round(bc_seq, 2), losses def main(): a_seq = jnp.array([3.6] * 5) bc_seq = jnp.array([[-5.86805555556, 2.82118055556]] * 5) for i in range(10): bc_seq, objectives = optimize_bc(a_seq, bc_seq, lr=1e-3, n=100000) print(bc_seq) print(i, objectives[0][-1], objectives[1][-1], objectives[2][-1]) for i in range(10): bc_seq, objectives = optimize_bc(a_seq, bc_seq, lr=2e-4, n=100000) print(bc_seq) print(i, objectives[0][-1], objectives[1][-1], objectives[2][-1]) bc_seq, objectives = optimize_bc(a_seq, bc_seq, lr=2e-5, n=10000, debug=True) print(bc_seq) if __name__ == "__main__": -
YouJiacheng created this gist
Feb 17, 2025 . There are no files selected for viewing.
"""Tune the cubic/quintic coefficients of a chain of odd polynomials with Adam.

Every stage of the chain is ``p(x) = a*x + b*x**3 + c*x**5``.  The linear
coefficients ``a`` are held fixed; only the ``(b, c)`` pairs are optimized.
"""

# Reconstructed from the collapsed diff hunk ``@@ -0,0 +1,89 @@`` (the initial
# revision of the gist); the original line structure is restored here.

from functools import partial

import jax
import jax.numpy as jnp
import optax


def poly(x: jnp.ndarray, w: jnp.ndarray) -> jnp.ndarray:
    """Evaluate the odd quintic ``w[0]*x + w[1]*x**3 + w[2]*x**5``.

    Args:
        x: Point(s) at which to evaluate the polynomial.
        w: Coefficient vector of shape ``(3,)``.
    """
    assert w.shape == (3,)
    return w[0] * x + w[1] * x**3 + w[2] * x**5


def poly_chain(x: jnp.ndarray, w_seq: jnp.ndarray) -> jnp.ndarray:
    """Apply ``poly`` repeatedly, once per coefficient row of ``w_seq``, in order."""
    y = x
    for w in w_seq:
        y = poly(y, w)
    return y


@partial(jax.jit, static_argnums=(3, 4))
def optimize_bc(
    a_seq: jnp.ndarray,
    bc_seq: jnp.ndarray,
    lr: float,
    n: int,
    debug: bool = False,
):
    """Run ``n`` Adam steps on ``bc_seq`` and return the result with its history.

    Args:
        a_seq: Fixed linear coefficients, shape ``(k,)``.
        bc_seq: Initial cubic/quintic coefficients, shape ``(k, 2)``.
        lr: Adam learning rate.
        n: Number of optimization steps (static under ``jit``).
        debug: When true, print ``loss_1`` and ``loss_2`` on every loss
            evaluation (static under ``jit``).

    Returns:
        ``(bc_seq, (loss_1, loss_2))``: the updated coefficients and the two
        tracked loss terms stacked over the ``n`` steps.  Each recorded value
        is computed at the parameters *before* that step's update.
    """

    def loss(bc_seq: jnp.ndarray):
        # Three grids over [0, 0.01], [0, 0.1] and [0, 1]: the small-x region
        # is sampled much more densely than the rest of the unit interval.
        xs = jnp.concat(
            [
                jnp.linspace(0, 0.01, 1000),
                jnp.linspace(0, 0.1, 1000),
                jnp.linspace(0, 1, 1000),
            ]
        )
        assert a_seq.shape == (bc_seq.shape[0],)
        # Row i of w_seq is (a_i, b_i, c_i).
        w_seq = jnp.concat([a_seq[:, None], bc_seq], axis=1)
        ys = jax.vmap(poly_chain, in_axes=(0, None))(xs, w_seq)
        assert ys.shape == (3000,)

        y_max = jnp.amax(ys)
        # Inputs at or below 0.01 are excluded from the minimum.
        y_min = jnp.amin(jnp.where(xs > 0.01, ys, jnp.inf))
        # Relative spread of the outputs; the denominator is floored at 1e-3.
        diff_ratio = (y_max - y_min) / jnp.clip(y_max, min=1e-3)

        loss_1 = y_max  # smaller is better
        loss_2 = jnp.clip(diff_ratio, min=0.1025)  # smaller is better
        # NOTE(review): column 0 of w_seq is the fixed a_seq, so loss_3 does
        # not depend on bc_seq and contributes no gradient.  It may have been
        # meant to index the b/c columns -- confirm intent before changing.
        loss_3 = jnp.max((w_seq[:, 0] - (-6.0)) ** 2)  # regularizer
        loss_4 = jnp.max((w_seq[:, 1] / w_seq[:, 0]) ** 2)  # regularizer
        if debug:
            jax.debug.print("{x} {y}", x=loss_1, y=loss_2)
        total = loss_1 + 10.0 * loss_2 + 1e-4 * loss_3 + 1e-4 * loss_4
        return total, (loss_1, loss_2)

    loss_and_grad_fn = jax.value_and_grad(loss, argnums=0, has_aux=True)
    # The clip is chained *after* Adam, so it acts on Adam's output (the
    # update) rather than on the raw gradient.
    optimizer = optax.chain(
        optax.adam(learning_rate=lr),
        optax.clip_by_global_norm(1.0),
    )
    opt_state = optimizer.init(bc_seq)

    def body_fn(carry: tuple[jnp.ndarray, optax.OptState], _):
        bc_seq, opt_state = carry
        (_, (loss_1, loss_2)), grad = loss_and_grad_fn(bc_seq)
        updates, opt_state = optimizer.update(grad, opt_state, params=bc_seq)
        bc_seq = optax.apply_updates(bc_seq, updates)
        return (bc_seq, opt_state), (loss_1, loss_2)

    (bc_seq, _), losses = jax.lax.scan(body_fn, (bc_seq, opt_state), length=n)
    return bc_seq, losses


def main() -> None:
    """Optimize five identical starting stages through a decaying LR schedule."""
    a_seq = jnp.array([3.6] * 5)
    bc_seq = jnp.array([[-5.86805555556, 2.82118055556]] * 5)

    # (learning rate, number of 100000-step rounds) for each stage.
    schedule = (
        (2e-3, 10),
        (1e-3, 20),
        (2e-4, 20),
    )
    for lr, rounds in schedule:
        for i in range(rounds):
            bc_seq, losses = optimize_bc(a_seq, bc_seq, lr=lr, n=100000)
            print(bc_seq)
            print(i, losses[0][-1], losses[1][-1])

    # NOTE: for a verbose final pass, run one more call with
    # lr=2e-5, n=10000, debug=True and print the resulting bc_seq.


if __name__ == "__main__":
    main()