@Edenhofer
Created March 16, 2023 18:31

eggholder_repro.py

# %%
from functools import partial
from typing import NamedTuple, Optional, Tuple, Union

import jax
from jax import lax
from jax import numpy as jnp

N_RESET = 20


class CGResults(NamedTuple):
    x: jnp.ndarray
    nit: Union[int, jnp.ndarray]
    nfev: Union[int, jnp.ndarray]  # number of matrix-evaluations
    info: Union[int, jnp.ndarray]
    success: Union[bool, jnp.ndarray]


# The following is code adapted from Nicholas Mancuso to work with pytrees
class _QuadSubproblemResult(NamedTuple):
    step: jnp.ndarray
    hits_boundary: Union[bool, jnp.ndarray]
    pred_f: Union[float, jnp.ndarray]
    nit: Union[int, jnp.ndarray]
    nfev: Union[int, jnp.ndarray]
    njev: Union[int, jnp.ndarray]
    nhev: Union[int, jnp.ndarray]
    success: Union[bool, jnp.ndarray]


class _CGSteihaugState(NamedTuple):
    z: jnp.ndarray
    r: jnp.ndarray
    d: jnp.ndarray
    step: jnp.ndarray
    energy: Union[None, float, jnp.ndarray]
    hits_boundary: Union[bool, jnp.ndarray]
    done: Union[bool, jnp.ndarray]
    nit: Union[int, jnp.ndarray]
    nhev: Union[int, jnp.ndarray]


def second_order_approx(
    p: jnp.ndarray,
    cur_val: Union[float, jnp.ndarray],
    g: jnp.ndarray,
    hessp_at_xk,
) -> Union[float, jnp.ndarray]:
    # local quadratic model: f(xk) + g @ p + 0.5 * p @ H @ p
    return cur_val + jnp.vdot(g, p) + 0.5 * jnp.vdot(p, hessp_at_xk(p))


def get_boundaries_intersections(
    z: jnp.ndarray, d: jnp.ndarray, trust_radius: Union[float, jnp.ndarray]
):
    # Solve ||z + t * d||^2 == trust_radius^2 for t and return both roots in
    # ascending order.
    a = jnp.vdot(d, d)
    b = 2 * jnp.vdot(z, d)
    c = jnp.vdot(z, z) - trust_radius**2
    sqrt_discriminant = jnp.sqrt(b * b - 4 * a * c)

    # compute the roots in a numerically stable way to avoid cancellation
    aux = b + jnp.copysign(sqrt_discriminant, b)
    ta = -aux / (2 * a)
    tb = -2 * c / aux

    ra = jnp.where(ta < tb, ta, tb)
    rb = jnp.where(ta < tb, tb, ta)
    return (ra, rb)
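
# Added sanity check (illustrative only, not part of the original repro): starting
# from the origin along a unit direction, the line z + t * d crosses a trust-region
# boundary of radius 2 at t = -2 and t = +2.
_ta, _tb = get_boundaries_intersections(jnp.zeros(2), jnp.array([1., 0.]), 2.)
# expect (_ta, _tb) == (-2.0, 2.0)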


def _cg_steihaug_subproblem(
    cur_val: Union[float, jnp.ndarray],
    g: jnp.ndarray,
    hessp_at_xk,
    *,
    trust_radius: Union[float, jnp.ndarray],
    tr_norm_ord: Union[None, int, float, jnp.ndarray] = None,
    resnorm: Optional[float],
    absdelta: Optional[float] = None,
    norm_ord: Union[None, int, float, jnp.ndarray] = None,
    miniter: Union[None, int] = None,
    maxiter: Union[None, int] = None,
) -> _QuadSubproblemResult:
    tr_norm_ord = jnp.inf if tr_norm_ord is None else tr_norm_ord  # taken from JAX
    norm_ord = 2 if norm_ord is None else norm_ord  # TODO: change to 1
    maxiter_fallback = 20 * g.size  # taken from SciPy's NewtonCG minimizer
    miniter = jnp.minimum(
        6, maxiter if maxiter is not None else maxiter_fallback
    ) if miniter is None else miniter
    maxiter = jnp.maximum(
        jnp.minimum(200, maxiter_fallback), miniter
    ) if maxiter is None else maxiter

    common_dtp = g.dtype
    eps = 6. * jnp.finfo(common_dtp).eps

    # second-order Taylor series approximation at the current values, gradient,
    # and hessian
    soa = partial(
        second_order_approx, cur_val=cur_val, g=g, hessp_at_xk=hessp_at_xk
    )

    # helpers for internal switches in the main CGSteihaug logic
    def noop(
        param: Tuple[_CGSteihaugState, Union[float, jnp.ndarray]]
    ) -> _CGSteihaugState:
        iterp, z_next = param
        return iterp

    def step1(
        param: Tuple[_CGSteihaugState, Union[float, jnp.ndarray]]
    ) -> _CGSteihaugState:
        # non-positive curvature along d: follow d to the trust-region boundary
        # and take the intersection with the lower model value (two extra model
        # evaluations, hence nhev + 2)
        iterp, z_next = param
        z, d, nhev = iterp.z, iterp.d, iterp.nhev

        ta, tb = get_boundaries_intersections(z, d, trust_radius)
        pa = z + ta * d
        pb = z + tb * d
        p_boundary = jnp.where(soa(pa) < soa(pb), pa, pb)
        return iterp._replace(
            step=p_boundary, nhev=nhev + 2, hits_boundary=True, done=True
        )

    def step2(
        param: Tuple[_CGSteihaugState, Union[float, jnp.ndarray]]
    ) -> _CGSteihaugState:
        # the proposed step leaves the trust region: clip it to the boundary
        iterp, z_next = param
        z, d = iterp.z, iterp.d

        ta, tb = get_boundaries_intersections(z, d, trust_radius)
        p_boundary = z + tb * d
        return iterp._replace(step=p_boundary, hits_boundary=True, done=True)

    def step3(
        param: Tuple[_CGSteihaugState, Union[float, jnp.ndarray]]
    ) -> _CGSteihaugState:
        # convergence (or iteration limit) reached: accept the interior step
        iterp, z_next = param
        return iterp._replace(step=z_next, hits_boundary=False, done=True)

    # initialize the step
    p_origin = jnp.zeros_like(g)

    # init the state for the first iteration
    z = p_origin
    r = g
    d = -r
    energy = 0.
    init_param = _CGSteihaugState(
        z=z,
        r=r,
        d=d,
        step=p_origin,
        energy=energy,
        hits_boundary=False,
        done=maxiter == 0,
        nit=0,
        nhev=0
    )

    # Search for the min of the approximation of the objective function.
    def body_f(iterp: _CGSteihaugState) -> _CGSteihaugState:
        z, r, d = iterp.z, iterp.r, iterp.d
        energy, nit = iterp.energy, iterp.nit

        nit += 1
        jax.debug.print("in body {nit} \\ 1 ::", nit=nit)

        # standard CG update using the Hessian-vector product
        Bd = hessp_at_xk(d)
        dBd = jnp.vdot(d, Bd)

        r_squared = jnp.vdot(r, r)
        alpha = r_squared / dBd
        z_next = z + alpha * d

        r_next = r + alpha * Bd
        r_next_squared = jnp.vdot(r_next, r_next)

        beta_next = r_next_squared / r_squared
        d_next = -r_next + beta_next * d
        jax.debug.print("in body {nit} \\ 2 ::", nit=nit)

        accept_z_next = nit >= maxiter
        jax.debug.print(
            "in body {nit} \\ 3 :: accept_z_next={accept_z_next}",
            nit=nit,
            accept_z_next=accept_z_next
        )
        if norm_ord == 2:
            r_next_norm = jnp.sqrt(r_next_squared)
        else:
            r_next_norm = jnp.linalg.norm(r_next, ord=norm_ord)
        accept_z_next |= r_next_norm < resnorm
        # Relative to a plain CG, `z_next` is negative
        energy_next = jnp.vdot((r_next + g) / 2, z_next)
        energy_diff = energy - energy_next
        if absdelta is not None:
            neg_energy_eps = -eps * jnp.abs(energy)
            accept_z_next |= (
                (energy_diff >= neg_energy_eps) & (energy_diff < absdelta) &
                (nit >= miniter)
            )
        jax.debug.print("in body {nit} \\ 4 ::", nit=nit)

        # include a junk switch to catch the case where none should be executed
        z_next_norm = jnp.linalg.norm(z_next, ord=tr_norm_ord)
        jax.debug.print("in body {nit} \\ 5 :: pre-index", nit=nit)
        index = jnp.argmax(
            jnp.array(
                [False, dBd <= 0, z_next_norm >= trust_radius, accept_z_next]
            )
        )
        jax.debug.print(
            "in body {nit} \\ 6 :: pre-switch {index}", nit=nit, index=index
        )
        iterp = lax.switch(index, [noop, step1, step2, step3], (iterp, z_next))
        jax.debug.print("in body {nit} \\ 7 :: post-switch", nit=nit)

        iterp = iterp._replace(
            z=z_next,
            r=r_next,
            d=d_next,
            energy=energy_next,
            nhev=iterp.nhev + 1,
            nit=nit
        )

        return iterp

    def cond_f(iterp: _CGSteihaugState) -> bool:
        jax.debug.print(
            "cond_f={c} maxiter={maxiter}", c=~iterp.done, maxiter=maxiter
        )
        return jnp.logical_not(iterp.done)

    # perform inner optimization to solve the constrained
    # quadratic subproblem using cg
    jax.debug.print("looped {result.done} {result}", result=init_param)
    result = lax.while_loop(cond_f, body_f, init_param)
    jax.debug.print("looped {result.done} {result}", result=result)

    pred_f = soa(result.step)
    result = _QuadSubproblemResult(
        step=result.step,
        hits_boundary=result.hits_boundary,
        pred_f=pred_f,
        nit=result.nit,
        nfev=0,
        njev=0,
        nhev=result.nhev + 1,
        success=True
    )

    return result


def rosenbrock(np):
    def func(x):
        return jnp.sum(100. * jnp.diff(x)**2 + (1. - x[:-1])**2)

    return func


def himmelblau(np):
    def func(p):
        x, y = p
        return (x**2 + y - 11.)**2 + (x + y**2 - 7.)**2

    return func


def matyas(np):
    def func(p):
        x, y = p
        return 0.26 * (x**2 + y**2) - 0.48 * x * y

    return func


def eggholder(np):
    def func(p):
        x, y = p
        return -(y + 47.) * jnp.sin(
            jnp.sqrt(jnp.abs(x / 2. + y + 47.))
        ) - x * jnp.sin(jnp.sqrt(jnp.abs(x - (y + 47.))))

    return func
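
# Added reference check (illustrative only, not part of the original repro): the
# commonly cited global minimum of the eggholder benchmark lies near
# (512, 404.2319) with a value of roughly -959.64.
print("eggholder(512, 404.2319) =", eggholder(jnp)(jnp.array([512., 404.2319])))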


def hessp(primals, tangents):
    # Hessian-vector product of `fun` via a JVP of its gradient
    # (forward-over-reverse differentiation)
    return jax.jvp(jax.grad(fun), (primals, ), (tangents, ))[1]


fun = eggholder(jnp)
x0 = jnp.ones(2) * 100.
f0, g0 = jax.value_and_grad(fun)(x0)
kwargs = {
    "absdelta": 0.,
    "resnorm": 0.,
    "trust_radius": 1.,
    "norm_ord": 1,
}
res = _cg_steihaug_subproblem(f0, g0, partial(hessp, x0), **kwargs)
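
# Added inspection sketch (not part of the original repro): print the fields of the
# _QuadSubproblemResult captured above to see the proposed step, the model
# prediction, and the iteration counters alongside the debug output.
print("step:", res.step)
print("hits_boundary:", res.hits_boundary, "pred_f:", res.pred_f)
print("nit:", res.nit, "nhev:", res.nhev)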