@YouJiacheng
Last active August 26, 2025 08:24
hp-opt.py
Revisions

  1. YouJiacheng revised this gist Feb 24, 2025. 1 changed file with 13 additions and 16 deletions.
    29 changes: 13 additions & 16 deletions hp-opt.py
@@ -1,6 +1,3 @@
-from typing import Any
-
-
 from functools import partial
 
 import jax
@@ -10,6 +7,7 @@
 
 def poly(x: jnp.ndarray, w: jnp.ndarray):
     assert w.shape == (3,)
+    w = w.astype(jnp.float32)
     return w[0] * x + w[1] * x**3 + w[2] * x**5
 
 
@@ -35,17 +33,15 @@ def optimize_w(
     debug: bool = False,
 ):
     def loss(w_seq: jnp.ndarray):
-        w_seq = w_seq + jax.lax.stop_gradient(
-            jnp.round(1024 * w_seq, 0) / 1024 - w_seq
-        )  # STE
+        w_seq = w_seq.astype(jnp.bfloat16)
         xs = (jnp.arange(2048) + 1) / 2048
         *zs, ys = jax.vmap(poly_chain, in_axes=(0, None))(xs, w_seq)
         y_max = jnp.amax(ys)
         y_min = jnp.amin(jnp.where(xs > 1 / 128, ys, jnp.inf))
         diff_ratio = (y_max - y_min) / jnp.clip(y_max, min=1e-3)
 
         slope_xs = (jnp.arange(320) + 1) / 256
-        min_ps: Any = jax.vmap(min_of_polys, in_axes=(0, None))(slope_xs, w_seq)
+        min_ps = jax.vmap(min_of_polys, in_axes=(0, None))(slope_xs, w_seq)
         min_slope = jnp.amin(min_ps / slope_xs)
 
         z_max_seq = [jnp.amax(z) for z in zs]
@@ -65,8 +61,8 @@ def loss(w_seq: jnp.ndarray):
             jax.debug.print("{x}", x=objectives)
 
         loss = -4.0 * obj_0
-        loss += 16.0 * jnp.abs(obj_1 - 1)
-        loss += 2.0 * jnp.clip(obj_2, min=-8)
+        loss += 16.0 * jnp.square(obj_1 - 1)
+        loss += 2.0 * jnp.clip(obj_2, min=-10)
         loss += -4.0 * jnp.clip(obj_3, max=1 / 2)
         loss += 64.0 * obj_4
         return loss, objectives
@@ -81,31 +77,32 @@ def loss(w_seq: jnp.ndarray):
     def body_fn(carry: tuple[jnp.ndarray, optax.OptState], _):
         w_seq, opt_state = carry
         (_, objectives), grad = loss_and_grad_fn(w_seq)
-        updates, opt_state = optimizer.update(grad, opt_state, params=w_seq)
+        updates, opt_state = optimizer.update(grad, opt_state)
         w_seq = optax.apply_updates(w_seq, updates)
         return (w_seq, opt_state), objectives
 
     (w_seq, _), objectives = jax.lax.scan(body_fn, (w_seq, opt_state), length=n)
-    return jnp.round(1024 * w_seq, 0) / 1024, objectives
+    return w_seq, objectives
 
 
 def main():
+    BASE = 128
     w_seq = jnp.array([[3.5, -6.04444444444, 2.84444444444]] * 6)
     for i in range(5):
         w_seq, objectives = optimize_w(w_seq, lr=2e-3, n=100000)
-        print(w_seq * 1024)
+        print(w_seq.astype(jnp.bfloat16) * BASE)
         print(i, [obj[-1].item() for obj in objectives])
     for i in range(5):
         w_seq, objectives = optimize_w(w_seq, lr=1e-3, n=100000)
-        print(w_seq * 1024)
+        print(w_seq.astype(jnp.bfloat16) * BASE)
         print(i, [obj[-1].item() for obj in objectives])
     for i in range(5):
         w_seq, objectives = optimize_w(w_seq, lr=5e-4, n=100000)
-        print(w_seq * 1024)
+        print(w_seq.astype(jnp.bfloat16) * BASE)
         print(i, [obj[-1].item() for obj in objectives])
-    for i in range(5):
+    for i in range(20):
         w_seq, objectives = optimize_w(w_seq, lr=1e-4, n=100000)
-        print(w_seq * 1024)
+        print(w_seq.astype(jnp.bfloat16) * BASE)
        print(i, [obj[-1].item() for obj in objectives])
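
Note: as a standalone sketch (not part of the gist) of what this final revision's loss probes, one can compose the six odd quintics p(x) = a*x + b*x**3 + c*x**5 over (0, 1] with the initial, unoptimized coefficients from main() and inspect the spread of the composed map.

import jax.numpy as jnp

# Sketch only: evaluate the 6-step chain on the grid the loss uses,
# with the initial (unoptimized) coefficients from main().
def poly(x, w):
    w = w.astype(jnp.float32)
    return w[0] * x + w[1] * x**3 + w[2] * x**5

w_seq = jnp.array([[3.5, -6.04444444444, 2.84444444444]] * 6, dtype=jnp.bfloat16)
xs = (jnp.arange(2048) + 1) / 2048
ys = xs
for w in w_seq:
    ys = poly(ys, w)

y_max = jnp.amax(ys)
y_min = jnp.amin(jnp.where(xs > 1 / 128, ys, jnp.inf))
print(float(y_max), float(y_min))  # the objectives drive both toward 1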


  2. YouJiacheng revised this gist Feb 23, 2025. 1 changed file with 60 additions and 48 deletions.
    108 changes: 60 additions & 48 deletions hp-opt.py
@@ -1,3 +1,6 @@
+from typing import Any
+
+
 from functools import partial
 
 import jax
@@ -11,90 +14,99 @@ def poly(x: jnp.ndarray, w: jnp.ndarray):
 
 
 def poly_chain(x: jnp.ndarray, w_seq: jnp.ndarray):
-    y = x
+    y = [x]
     for w in w_seq:
-        y = poly(y, w)
+        y.append(poly(y[-1], w))
     return y
 
 
-def poly_min(x: jnp.ndarray, w_seq: jnp.ndarray):
+def min_of_polys(x: jnp.ndarray, w_seq: jnp.ndarray):
     y = jnp.full_like(x, jnp.inf)
     for w in w_seq:
         y = jnp.minimum(y, poly(x, w))
     return y
 
 
-@partial(jax.jit, static_argnums=(3, 4))
-def optimize_bc(
-    a_seq: jnp.ndarray,
-    bc_seq: jnp.ndarray,
+@partial(jax.jit, static_argnums=(2, 3))
+def optimize_w(
+    w_seq: jnp.ndarray,
     lr: float,
     n: int,
     debug: bool = False,
 ):
-    def loss(bc_seq: jnp.ndarray):
-        bc_seq = bc_seq + jax.lax.stop_gradient(jnp.round(1024 * bc_seq, 0) / 1024 - bc_seq)  # STE
-        xs = jnp.concat(
-            [
-                jnp.linspace(0, 1 / 128, 256),
-                jnp.linspace(0, 1 / 16, 512),
-                jnp.linspace(0, 1, 1024),
-                jnp.linspace(1, 1 + 1 / 8, 256),
-            ]
-        )
-        assert a_seq.shape == (bc_seq.shape[0],)
-        w_seq = jnp.concat([a_seq[:, None], bc_seq], axis=1)
-        ys = jax.vmap(poly_chain, in_axes=(0, None))(xs, w_seq)
+    def loss(w_seq: jnp.ndarray):
+        w_seq = w_seq + jax.lax.stop_gradient(
+            jnp.round(1024 * w_seq, 0) / 1024 - w_seq
+        )  # STE
+        xs = (jnp.arange(2048) + 1) / 2048
+        *zs, ys = jax.vmap(poly_chain, in_axes=(0, None))(xs, w_seq)
         y_max = jnp.amax(ys)
         y_min = jnp.amin(jnp.where(xs > 1 / 128, ys, jnp.inf))
         diff_ratio = (y_max - y_min) / jnp.clip(y_max, min=1e-3)
 
-        min_zs = jax.vmap(poly_min, in_axes=(0, None))(xs, w_seq)
-        min_slope = jnp.amin(
-            jnp.where(xs > 1 / 128, min_zs / jnp.clip(xs, min=1 / 128), jnp.inf)
+        slope_xs = (jnp.arange(320) + 1) / 256
+        min_ps: Any = jax.vmap(min_of_polys, in_axes=(0, None))(slope_xs, w_seq)
+        min_slope = jnp.amin(min_ps / slope_xs)
+
+        z_max_seq = [jnp.amax(z) for z in zs]
+        max_next_excess = sum(
+            jnp.clip(poly(z + 1 / 16, w) - z, min=0) for z, w in zip(z_max_seq, w_seq)
        )
 
-        obj_1 = y_max  # smaller is better
+        obj_0 = ys[0] / y_max  # larger is better
+        obj_1 = y_max  # closer to 1 is better
         obj_2 = jnp.log2(diff_ratio)  # smaller is better
         obj_3 = min_slope  # larger is better
+        obj_4 = max_next_excess  # smaller is better
+
+        objectives = (obj_0, obj_1, obj_2, obj_3, obj_4)
 
         if debug:
-            jax.debug.print("{x} {y} {z}", x=obj_1, y=obj_2, z=obj_3)
-        return jnp.clip(obj_1, min=1) + 4.0 * jnp.clip(
-            obj_2, min=jnp.log2(1 / 1024)
-        ) + -8.0 * jnp.clip(obj_3, max=0.5), (obj_1, obj_2, obj_3)
+            jax.debug.print("{x}", x=objectives)
+
+        loss = -4.0 * obj_0
+        loss += 16.0 * jnp.abs(obj_1 - 1)
+        loss += 2.0 * jnp.clip(obj_2, min=-8)
+        loss += -4.0 * jnp.clip(obj_3, max=1 / 2)
+        loss += 64.0 * obj_4
+        return loss, objectives
 
     loss_and_grad_fn = jax.value_and_grad(loss, argnums=0, has_aux=True)
     optimizer = optax.chain(
         optax.adam(learning_rate=lr),
         optax.clip_by_global_norm(1.0),
     )
-    opt_state = optimizer.init(bc_seq)
+    opt_state = optimizer.init(w_seq)
 
     def body_fn(carry: tuple[jnp.ndarray, optax.OptState], _):
-        bc_seq, opt_state = carry
-        (_, (obj_1, obj_2, obj_3)), grad = loss_and_grad_fn(bc_seq)
-        updates, opt_state = optimizer.update(grad, opt_state, params=bc_seq)
-        bc_seq = optax.apply_updates(bc_seq, updates)
-        return (bc_seq, opt_state), (obj_1, obj_2, obj_3)
+        w_seq, opt_state = carry
+        (_, objectives), grad = loss_and_grad_fn(w_seq)
+        updates, opt_state = optimizer.update(grad, opt_state, params=w_seq)
+        w_seq = optax.apply_updates(w_seq, updates)
+        return (w_seq, opt_state), objectives
 
-    (bc_seq, _), losses = jax.lax.scan(body_fn, (bc_seq, opt_state), length=n)
-    return jnp.round(1024 * bc_seq, 0) / 1024, losses
+    (w_seq, _), objectives = jax.lax.scan(body_fn, (w_seq, opt_state), length=n)
+    return jnp.round(1024 * w_seq, 0) / 1024, objectives
 
 
 def main():
-    a_seq = jnp.array([3.5] * 6)
-    bc_seq = jnp.array([[-6.04444444444, 2.84444444444]] * 6)
-    for i in range(20):
-        bc_seq, objectives = optimize_bc(a_seq, bc_seq, lr=5e-4, n=100000)
-        print(bc_seq * 1024)
-        print(i, objectives[0][-1], objectives[1][-1], objectives[2][-1])
-    for i in range(20):
-        bc_seq, objectives = optimize_bc(a_seq, bc_seq, lr=1e-4, n=100000)
-        print(bc_seq * 1024)
-        print(i, objectives[0][-1], objectives[1][-1], objectives[2][-1])
-    bc_seq, objectives = optimize_bc(a_seq, bc_seq, lr=2e-5, n=10000, debug=True)
-    print(bc_seq * 1024)
+    w_seq = jnp.array([[3.5, -6.04444444444, 2.84444444444]] * 6)
+    for i in range(5):
+        w_seq, objectives = optimize_w(w_seq, lr=2e-3, n=100000)
+        print(w_seq * 1024)
+        print(i, [obj[-1].item() for obj in objectives])
+    for i in range(5):
+        w_seq, objectives = optimize_w(w_seq, lr=1e-3, n=100000)
+        print(w_seq * 1024)
+        print(i, [obj[-1].item() for obj in objectives])
+    for i in range(5):
+        w_seq, objectives = optimize_w(w_seq, lr=5e-4, n=100000)
+        print(w_seq * 1024)
+        print(i, [obj[-1].item() for obj in objectives])
+    for i in range(5):
+        w_seq, objectives = optimize_w(w_seq, lr=1e-4, n=100000)
+        print(w_seq * 1024)
+        print(i, [obj[-1].item() for obj in objectives])
 
 
 if __name__ == "__main__":
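
Note: the structural change in this revision is that poly_chain now returns every intermediate, which lets the loss penalize per-step overshoot (the max_next_excess term). A minimal sketch (not part of the gist) of reading off the per-step maxima with the initial coefficients:

import jax
import jax.numpy as jnp

# Sketch only: poly_chain returns [x, p1(x), p2(p1(x)), ...]; vmapping it
# yields one array per step, so each step's maximum can be inspected.
def poly(x, w):
    return w[0] * x + w[1] * x**3 + w[2] * x**5

def poly_chain(x, w_seq):
    y = [x]
    for w in w_seq:
        y.append(poly(y[-1], w))
    return y

w_seq = jnp.array([[3.5, -6.04444444444, 2.84444444444]] * 6)
xs = (jnp.arange(2048) + 1) / 2048
*zs, ys = jax.vmap(poly_chain, in_axes=(0, None))(xs, w_seq)
print([float(jnp.amax(z)) for z in zs])  # per-step maxima feeding max_next_excess
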
  3. YouJiacheng revised this gist Feb 23, 2025. 1 changed file with 15 additions and 15 deletions.
    30 changes: 15 additions & 15 deletions hp-opt.py
@@ -33,7 +33,7 @@ def optimize_bc(
     debug: bool = False,
 ):
     def loss(bc_seq: jnp.ndarray):
-        bc_seq = bc_seq + jax.lax.stop_gradient(jnp.round(bc_seq, 2) - bc_seq)  # STE
+        bc_seq = bc_seq + jax.lax.stop_gradient(jnp.round(1024 * bc_seq, 0) / 1024 - bc_seq)  # STE
         xs = jnp.concat(
             [
                 jnp.linspace(0, 1 / 128, 256),
@@ -55,14 +55,14 @@ def loss(bc_seq: jnp.ndarray):
         )
 
         obj_1 = y_max  # smaller is better
-        obj_2 = diff_ratio  # smaller is better
+        obj_2 = jnp.log2(diff_ratio)  # smaller is better
         obj_3 = min_slope  # larger is better
 
         if debug:
             jax.debug.print("{x} {y} {z}", x=obj_1, y=obj_2, z=obj_3)
-        return jnp.clip(obj_1, min=1) + 8.0 * jnp.clip(
-            obj_2, min=1 / 32
-        ) + -4.0 * jnp.clip(obj_3, max=0.5), (obj_1, obj_2, obj_3)
+        return jnp.clip(obj_1, min=1) + 4.0 * jnp.clip(
+            obj_2, min=jnp.log2(1 / 1024)
+        ) + -8.0 * jnp.clip(obj_3, max=0.5), (obj_1, obj_2, obj_3)
 
     loss_and_grad_fn = jax.value_and_grad(loss, argnums=0, has_aux=True)
     optimizer = optax.chain(
@@ -79,22 +79,22 @@ def body_fn(carry: tuple[jnp.ndarray, optax.OptState], _):
         return (bc_seq, opt_state), (obj_1, obj_2, obj_3)
 
     (bc_seq, _), losses = jax.lax.scan(body_fn, (bc_seq, opt_state), length=n)
-    return jnp.round(bc_seq, 2), losses
+    return jnp.round(1024 * bc_seq, 0) / 1024, losses
 
 
 def main():
-    a_seq = jnp.array([3.6] * 5)
-    bc_seq = jnp.array([[-5.86805555556, 2.82118055556]] * 5)
-    for i in range(10):
-        bc_seq, objectives = optimize_bc(a_seq, bc_seq, lr=1e-3, n=100000)
-        print(bc_seq)
+    a_seq = jnp.array([3.5] * 6)
+    bc_seq = jnp.array([[-6.04444444444, 2.84444444444]] * 6)
+    for i in range(20):
+        bc_seq, objectives = optimize_bc(a_seq, bc_seq, lr=5e-4, n=100000)
+        print(bc_seq * 1024)
         print(i, objectives[0][-1], objectives[1][-1], objectives[2][-1])
-    for i in range(10):
-        bc_seq, objectives = optimize_bc(a_seq, bc_seq, lr=2e-4, n=100000)
-        print(bc_seq)
+    for i in range(20):
+        bc_seq, objectives = optimize_bc(a_seq, bc_seq, lr=1e-4, n=100000)
+        print(bc_seq * 1024)
         print(i, objectives[0][-1], objectives[1][-1], objectives[2][-1])
     bc_seq, objectives = optimize_bc(a_seq, bc_seq, lr=2e-5, n=10000, debug=True)
-    print(bc_seq)
+    print(bc_seq * 1024)
 
 
 if __name__ == "__main__":
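
Note: the "# STE" line is a straight-through estimator: the forward pass sees the coefficients snapped to a 1/1024 grid, while stop_gradient makes the quantization look like the identity to the backward pass, so gradients reach the underlying continuous parameters. A minimal standalone sketch (not from the gist):

import jax
import jax.numpy as jnp

def quantize_ste(w):
    # Forward: w rounded to the nearest 1/1024; backward: identity.
    w_q = jnp.round(1024 * w, 0) / 1024
    return w + jax.lax.stop_gradient(w_q - w)

w = jnp.float32(0.30003)
print(float(quantize_ste(w)))  # 0.2998046875, i.e. 307/1024
print(float(jax.grad(lambda w: quantize_ste(w) ** 2)(w)))  # 2 * 307/1024: gradient passes straight through
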
  4. YouJiacheng revised this gist Feb 17, 2025. 1 changed file with 37 additions and 25 deletions.
    62 changes: 37 additions & 25 deletions hp-opt.py
@@ -17,6 +17,13 @@ def poly_chain(x: jnp.ndarray, w_seq: jnp.ndarray):
     return y
 
 
+def poly_min(x: jnp.ndarray, w_seq: jnp.ndarray):
+    y = jnp.full_like(x, jnp.inf)
+    for w in w_seq:
+        y = jnp.minimum(y, poly(x, w))
+    return y
+
+
 @partial(jax.jit, static_argnums=(3, 4))
 def optimize_bc(
     a_seq: jnp.ndarray,
@@ -26,27 +33,36 @@ def optimize_bc(
     debug: bool = False,
 ):
     def loss(bc_seq: jnp.ndarray):
+        bc_seq = bc_seq + jax.lax.stop_gradient(jnp.round(bc_seq, 2) - bc_seq)  # STE
         xs = jnp.concat(
             [
-                jnp.linspace(0, 0.01, 1000),
-                jnp.linspace(0, 0.1, 1000),
-                jnp.linspace(0, 1, 1000),
+                jnp.linspace(0, 1 / 128, 256),
+                jnp.linspace(0, 1 / 16, 512),
+                jnp.linspace(0, 1, 1024),
+                jnp.linspace(1, 1 + 1 / 8, 256),
             ]
         )
         assert a_seq.shape == (bc_seq.shape[0],)
         w_seq = jnp.concat([a_seq[:, None], bc_seq], axis=1)
         ys = jax.vmap(poly_chain, in_axes=(0, None))(xs, w_seq)
-        assert ys.shape == (3000,)
         y_max = jnp.amax(ys)
-        y_min = jnp.amin(jnp.where(xs > 0.01, ys, jnp.inf))
+        y_min = jnp.amin(jnp.where(xs > 1 / 128, ys, jnp.inf))
         diff_ratio = (y_max - y_min) / jnp.clip(y_max, min=1e-3)
-        loss_1 = y_max  # smaller is better
-        loss_2 = jnp.clip(diff_ratio, min=0.1025)  # smaller is better
-        loss_3 = jnp.max((w_seq[:, 0] - (-6.0)) ** 2)  # regularizer
-        loss_4 = jnp.max((w_seq[:, 1] / w_seq[:, 0]) ** 2)  # regularizer
+
+        min_zs = jax.vmap(poly_min, in_axes=(0, None))(xs, w_seq)
+        min_slope = jnp.amin(
+            jnp.where(xs > 1 / 128, min_zs / jnp.clip(xs, min=1 / 128), jnp.inf)
+        )
+
+        obj_1 = y_max  # smaller is better
+        obj_2 = diff_ratio  # smaller is better
+        obj_3 = min_slope  # larger is better
+
         if debug:
-            jax.debug.print("{x} {y}", x=loss_1, y=loss_2)
-        return loss_1 + 10.0 * loss_2 + 1e-4 * loss_3 + 1e-4 * loss_4, (loss_1, loss_2)
+            jax.debug.print("{x} {y} {z}", x=obj_1, y=obj_2, z=obj_3)
+        return jnp.clip(obj_1, min=1) + 8.0 * jnp.clip(
+            obj_2, min=1 / 32
+        ) + -4.0 * jnp.clip(obj_3, max=0.5), (obj_1, obj_2, obj_3)
 
     loss_and_grad_fn = jax.value_and_grad(loss, argnums=0, has_aux=True)
     optimizer = optax.chain(
@@ -57,32 +73,28 @@ def loss(bc_seq: jnp.ndarray):
 
     def body_fn(carry: tuple[jnp.ndarray, optax.OptState], _):
         bc_seq, opt_state = carry
-        (_, (loss_1, loss_2)), grad = loss_and_grad_fn(bc_seq)
+        (_, (obj_1, obj_2, obj_3)), grad = loss_and_grad_fn(bc_seq)
         updates, opt_state = optimizer.update(grad, opt_state, params=bc_seq)
         bc_seq = optax.apply_updates(bc_seq, updates)
-        return (bc_seq, opt_state), (loss_1, loss_2)
+        return (bc_seq, opt_state), (obj_1, obj_2, obj_3)
 
     (bc_seq, _), losses = jax.lax.scan(body_fn, (bc_seq, opt_state), length=n)
-    return bc_seq, losses
+    return jnp.round(bc_seq, 2), losses
 
 
 def main():
     a_seq = jnp.array([3.6] * 5)
     bc_seq = jnp.array([[-5.86805555556, 2.82118055556]] * 5)
     for i in range(10):
-        bc_seq, losses = optimize_bc(a_seq, bc_seq, lr=2e-3, n=100000)
+        bc_seq, objectives = optimize_bc(a_seq, bc_seq, lr=1e-3, n=100000)
         print(bc_seq)
-        print(i, losses[0][-1], losses[1][-1])
-    for i in range(20):
-        bc_seq, losses = optimize_bc(a_seq, bc_seq, lr=1e-3, n=100000)
-        print(bc_seq)
-        print(i, losses[0][-1], losses[1][-1])
-    for i in range(20):
-        bc_seq, losses = optimize_bc(a_seq, bc_seq, lr=2e-4, n=100000)
+        print(i, objectives[0][-1], objectives[1][-1], objectives[2][-1])
+    for i in range(10):
+        bc_seq, objectives = optimize_bc(a_seq, bc_seq, lr=2e-4, n=100000)
         print(bc_seq)
-        print(i, losses[0][-1], losses[1][-1])
-    # bc_seq, losses = optimize_bc(a_seq, bc_seq, lr=2e-5, n=10000, debug=True)
-    # print(bc_seq)
+        print(i, objectives[0][-1], objectives[1][-1], objectives[2][-1])
+    bc_seq, objectives = optimize_bc(a_seq, bc_seq, lr=2e-5, n=10000, debug=True)
+    print(bc_seq)
 
 
 if __name__ == "__main__":
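
Note: one way to read the new min_slope objective (an interpretation, not stated in the gist): poly_min(x, .) is the lower envelope of the five per-step polynomials at x, and dividing by x lower-bounds the slope with which every step maps small inputs away from zero; the loss rewards this up to 0.5. A sketch with this revision's initial coefficients:

import jax
import jax.numpy as jnp

# Sketch only: lower envelope of the per-step polynomials, divided by x.
def poly(x, w):
    return w[0] * x + w[1] * x**3 + w[2] * x**5

def poly_min(x, w_seq):
    y = jnp.full_like(x, jnp.inf)
    for w in w_seq:
        y = jnp.minimum(y, poly(x, w))
    return y

a_seq = jnp.array([3.6] * 5)
bc_seq = jnp.array([[-5.86805555556, 2.82118055556]] * 5)
w_seq = jnp.concat([a_seq[:, None], bc_seq], axis=1)  # rows are (a, b, c)
xs = jnp.linspace(1 / 128, 1, 256)
min_zs = jax.vmap(poly_min, in_axes=(0, None))(xs, w_seq)
print(float(jnp.amin(min_zs / xs)))  # min_slope: larger is better
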
  5. YouJiacheng created this gist Feb 17, 2025.
    89 changes: 89 additions & 0 deletions hp-opt.py
@@ -0,0 +1,89 @@
+from functools import partial
+
+import jax
+import jax.numpy as jnp
+import optax
+
+
+def poly(x: jnp.ndarray, w: jnp.ndarray):
+    assert w.shape == (3,)
+    return w[0] * x + w[1] * x**3 + w[2] * x**5
+
+
+def poly_chain(x: jnp.ndarray, w_seq: jnp.ndarray):
+    y = x
+    for w in w_seq:
+        y = poly(y, w)
+    return y
+
+
+@partial(jax.jit, static_argnums=(3, 4))
+def optimize_bc(
+    a_seq: jnp.ndarray,
+    bc_seq: jnp.ndarray,
+    lr: float,
+    n: int,
+    debug: bool = False,
+):
+    def loss(bc_seq: jnp.ndarray):
+        xs = jnp.concat(
+            [
+                jnp.linspace(0, 0.01, 1000),
+                jnp.linspace(0, 0.1, 1000),
+                jnp.linspace(0, 1, 1000),
+            ]
+        )
+        assert a_seq.shape == (bc_seq.shape[0],)
+        w_seq = jnp.concat([a_seq[:, None], bc_seq], axis=1)
+        ys = jax.vmap(poly_chain, in_axes=(0, None))(xs, w_seq)
+        assert ys.shape == (3000,)
+        y_max = jnp.amax(ys)
+        y_min = jnp.amin(jnp.where(xs > 0.01, ys, jnp.inf))
+        diff_ratio = (y_max - y_min) / jnp.clip(y_max, min=1e-3)
+        loss_1 = y_max  # smaller is better
+        loss_2 = jnp.clip(diff_ratio, min=0.1025)  # smaller is better
+        loss_3 = jnp.max((w_seq[:, 0] - (-6.0)) ** 2)  # regularizer
+        loss_4 = jnp.max((w_seq[:, 1] / w_seq[:, 0]) ** 2)  # regularizer
+        if debug:
+            jax.debug.print("{x} {y}", x=loss_1, y=loss_2)
+        return loss_1 + 10.0 * loss_2 + 1e-4 * loss_3 + 1e-4 * loss_4, (loss_1, loss_2)
+
+    loss_and_grad_fn = jax.value_and_grad(loss, argnums=0, has_aux=True)
+    optimizer = optax.chain(
+        optax.adam(learning_rate=lr),
+        optax.clip_by_global_norm(1.0),
+    )
+    opt_state = optimizer.init(bc_seq)
+
+    def body_fn(carry: tuple[jnp.ndarray, optax.OptState], _):
+        bc_seq, opt_state = carry
+        (_, (loss_1, loss_2)), grad = loss_and_grad_fn(bc_seq)
+        updates, opt_state = optimizer.update(grad, opt_state, params=bc_seq)
+        bc_seq = optax.apply_updates(bc_seq, updates)
+        return (bc_seq, opt_state), (loss_1, loss_2)
+
+    (bc_seq, _), losses = jax.lax.scan(body_fn, (bc_seq, opt_state), length=n)
+    return bc_seq, losses
+
+
+def main():
+    a_seq = jnp.array([3.6] * 5)
+    bc_seq = jnp.array([[-5.86805555556, 2.82118055556]] * 5)
+    for i in range(10):
+        bc_seq, losses = optimize_bc(a_seq, bc_seq, lr=2e-3, n=100000)
+        print(bc_seq)
+        print(i, losses[0][-1], losses[1][-1])
+    for i in range(20):
+        bc_seq, losses = optimize_bc(a_seq, bc_seq, lr=1e-3, n=100000)
+        print(bc_seq)
+        print(i, losses[0][-1], losses[1][-1])
+    for i in range(20):
+        bc_seq, losses = optimize_bc(a_seq, bc_seq, lr=2e-4, n=100000)
+        print(bc_seq)
+        print(i, losses[0][-1], losses[1][-1])
+    # bc_seq, losses = optimize_bc(a_seq, bc_seq, lr=2e-5, n=10000, debug=True)
+    # print(bc_seq)
+
+
+if __name__ == "__main__":
+    main()