Created
March 16, 2023 18:31
-
-
Save Edenhofer/ece9a2e3e8c67721dbdd706b3966f04c to your computer and use it in GitHub Desktop.
Revisions
-
Edenhofer created this gist
Mar 16, 2023. There are no files selected for viewing
# %%
from functools import partial
from typing import NamedTuple, Optional, Tuple, Union

import jax
from jax import lax
from jax import numpy as jnp

N_RESET = 20


class CGResults(NamedTuple):
    """Result record for a conjugate-gradient solve.

    NOTE(review): not referenced by any other code in this file.
    """
    x: jnp.ndarray
    nit: Union[int, jnp.ndarray]
    nfev: Union[int, jnp.ndarray]  # number of matrix-evaluations
    info: Union[int, jnp.ndarray]
    success: Union[bool, jnp.ndarray]


# The following is code adapted from Nicholas Mancuso to work with pytrees
class _QuadSubproblemResult(NamedTuple):
    """Outcome of the trust-region quadratic subproblem solver."""
    step: jnp.ndarray  # proposed step
    hits_boundary: Union[bool, jnp.ndarray]  # step was placed on the boundary
    pred_f: Union[float, jnp.ndarray]  # quadratic model evaluated at `step`
    nit: Union[int, jnp.ndarray]  # number of loop iterations
    nfev: Union[int, jnp.ndarray]
    njev: Union[int, jnp.ndarray]
    nhev: Union[int, jnp.ndarray]  # number of Hessian-vector products counted
    success: Union[bool, jnp.ndarray]


class _CGSteihaugState(NamedTuple):
    """Loop carry of the CG-Steihaug iteration."""
    z: jnp.ndarray  # current CG iterate
    r: jnp.ndarray  # current residual
    d: jnp.ndarray  # current search direction
    step: jnp.ndarray  # step selected once the loop is done
    energy: Union[None, float, jnp.ndarray]
    hits_boundary: Union[bool, jnp.ndarray]
    done: Union[bool, jnp.ndarray]  # loop-termination flag
    nit: Union[int, jnp.ndarray]
    nhev: Union[int, jnp.ndarray]


def second_order_approx(
    p: jnp.ndarray,
    cur_val: Union[float, jnp.ndarray],
    g: jnp.ndarray,
    hessp_at_xk,
) -> Union[float, jnp.ndarray]:
    """Evaluate the quadratic model ``cur_val + <g, p> + 0.5 * <p, H p>``.

    ``hessp_at_xk`` is called once with ``p`` and must return ``H p``.
    """
    return cur_val + jnp.vdot(g, p) + 0.5 * jnp.vdot(p, hessp_at_xk(p))


def get_boundaries_intersections(
    z: jnp.ndarray, d: jnp.ndarray, trust_radius: Union[float, jnp.ndarray]
):
    """Solve ``<z + t*d, z + t*d> = trust_radius**2`` for the scalar ``t``.

    Returns the two roots of the quadratic as a tuple ``(ra, rb)`` with
    ``ra <= rb``.
    """
    a = jnp.vdot(d, d)
    b = 2 * jnp.vdot(z, d)
    c = jnp.vdot(z, z) - trust_radius**2
    sqrt_discriminant = jnp.sqrt(b * b - 4 * a * c)
    # Adding `b` and a root term of the same sign avoids cancellation; the
    # second root then follows from the product of roots `c / a`.
    aux = b + jnp.copysign(sqrt_discriminant, b)
    ta = -aux / (2 * a)
    tb = -2 * c / aux
    ra = jnp.where(ta < tb, ta, tb)
    rb = jnp.where(ta < tb, tb, ta)
    return (ra, rb)


def _cg_steihaug_subproblem(
    cur_val: Union[float, jnp.ndarray],
    g: jnp.ndarray,
    hessp_at_xk,
    *,
    trust_radius: Union[float, jnp.ndarray],
    tr_norm_ord: Union[None, int, float, jnp.ndarray] = None,
    resnorm: Optional[float],
    absdelta: Optional[float] = None,
    norm_ord: Union[None, int, float, jnp.ndarray] = None,
    miniter: Union[None, int] = None,
    maxiter: Union[None, int] = None,
) -> _QuadSubproblemResult:
    """Approximately minimise the quadratic model within a trust region.

    Conjugate-gradient iterations start from the zero step and run inside a
    `lax.while_loop`. Each iteration ends the loop as soon as one of the
    following holds (checked in this order):

    1. the curvature ``<d, H d>`` along the search direction is ``<= 0``:
       the better of the two boundary points along ``d`` is taken;
    2. the next iterate's norm (order `tr_norm_ord`) is ``>= trust_radius``:
       the boundary point at the larger root along ``d`` is taken;
    3. the next iterate is accepted because `maxiter` is reached, the
       residual norm drops below `resnorm`, or (only if `absdelta` is not
       None) the energy change lies in ``[-eps * |energy|, absdelta)`` after
       at least `miniter` iterations.

    NOTE(review): the boundary points come from
    `get_boundaries_intersections`, which uses the Euclidean inner product
    regardless of `tr_norm_ord` - confirm that this is intended.

    Parameters
    ----------
    cur_val:
        Constant term of the quadratic model.
    g:
        Linear term (gradient) of the quadratic model.
    hessp_at_xk:
        Callable mapping a vector ``p`` to the Hessian-vector product.
    trust_radius:
        Radius of the trust region.
    tr_norm_ord:
        Order of the norm used for the trust-region test; ``jnp.inf`` if None.
    resnorm:
        Residual-norm threshold. It is compared unconditionally, so it must
        be a number even though the annotation allows None.
    absdelta:
        Energy-change threshold; the energy criterion is skipped if None.
    norm_ord:
        Order of the residual norm; 2 if None.
    miniter:
        Minimum iterations before the energy criterion may trigger. Defaults
        to ``min(6, maxiter)``, with ``20 * g.size`` standing in for a
        missing `maxiter`.
    maxiter:
        Iteration limit. Defaults to ``max(min(200, 20 * g.size), miniter)``.

    Returns
    -------
    _QuadSubproblemResult
        ``nfev`` and ``njev`` are always 0 and ``success`` is always True.
        ``nhev`` is the loop's count plus one for the final model evaluation.
    """
    tr_norm_ord = jnp.inf if tr_norm_ord is None else tr_norm_ord  # taken from JAX
    norm_ord = 2 if norm_ord is None else norm_ord  # TODO: change to 1
    maxiter_fallback = 20 * g.size  # taken from SciPy's NewtonCG minimizer
    miniter = jnp.minimum(
        6, maxiter if maxiter is not None else maxiter_fallback
    ) if miniter is None else miniter
    maxiter = jnp.maximum(
        jnp.minimum(200, maxiter_fallback), miniter
    ) if maxiter is None else maxiter

    common_dtp = g.dtype
    eps = 6. * jnp.finfo(common_dtp).eps

    # second-order Taylor series approximation at the current values, gradient,
    # and hessian
    soa = partial(
        second_order_approx, cur_val=cur_val, g=g, hessp_at_xk=hessp_at_xk
    )

    # helpers for internal switches in the main CGSteihaug logic
    def noop(
        param: Tuple[_CGSteihaugState, Union[float, jnp.ndarray]]
    ) -> _CGSteihaugState:
        # No stopping criterion fired: leave the state untouched.
        iterp, _ = param
        return iterp

    def step1(
        param: Tuple[_CGSteihaugState, Union[float, jnp.ndarray]]
    ) -> _CGSteihaugState:
        # Non-positive curvature: pick whichever boundary point along `d`
        # has the lower model value. The two `soa` calls are counted as two
        # extra Hessian-vector products.
        iterp, _ = param
        z, d, nhev = iterp.z, iterp.d, iterp.nhev

        ta, tb = get_boundaries_intersections(z, d, trust_radius)
        pa = z + ta * d
        pb = z + tb * d
        p_boundary = jnp.where(soa(pa) < soa(pb), pa, pb)
        return iterp._replace(
            step=p_boundary, nhev=nhev + 2, hits_boundary=True, done=True
        )

    def step2(
        param: Tuple[_CGSteihaugState, Union[float, jnp.ndarray]]
    ) -> _CGSteihaugState:
        # Next iterate leaves the trust region: stop at the boundary point
        # given by the larger root along `d`.
        iterp, _ = param
        z, d = iterp.z, iterp.d

        _, tb = get_boundaries_intersections(z, d, trust_radius)
        p_boundary = z + tb * d
        return iterp._replace(step=p_boundary, hits_boundary=True, done=True)

    def step3(
        param: Tuple[_CGSteihaugState, Union[float, jnp.ndarray]]
    ) -> _CGSteihaugState:
        # Next iterate is accepted as the step.
        iterp, z_next = param
        return iterp._replace(step=z_next, hits_boundary=False, done=True)

    # initialize the step
    p_origin = jnp.zeros_like(g)

    # init the state for the first iteration
    z = p_origin
    r = g
    d = -r
    energy = 0.
    init_param = _CGSteihaugState(
        z=z,
        r=r,
        d=d,
        step=p_origin,
        energy=energy,
        hits_boundary=False,
        done=maxiter == 0,
        nit=0,
        nhev=0
    )

    # Search for the min of the approximation of the objective function.
    def body_f(iterp: _CGSteihaugState) -> _CGSteihaugState:
        z, r, d = iterp.z, iterp.r, iterp.d
        energy, nit = iterp.energy, iterp.nit

        nit += 1

        jax.debug.print("in body {nit} \\\\ 1 ::", nit=nit)
        Bd = hessp_at_xk(d)
        dBd = jnp.vdot(d, Bd)

        r_squared = jnp.vdot(r, r)
        alpha = r_squared / dBd
        z_next = z + alpha * d

        r_next = r + alpha * Bd
        r_next_squared = jnp.vdot(r_next, r_next)

        beta_next = r_next_squared / r_squared
        d_next = -r_next + beta_next * d

        jax.debug.print("in body {nit} \\\\ 2 ::", nit=nit)
        accept_z_next = nit >= maxiter
        jax.debug.print(
            "in body {nit} \\\\ 3 :: accept_z_next={accept_z_next}",
            nit=nit,
            accept_z_next=accept_z_next
        )
        if norm_ord == 2:
            r_next_norm = jnp.sqrt(r_next_squared)
        else:
            r_next_norm = jnp.linalg.norm(r_next, ord=norm_ord)
        accept_z_next |= r_next_norm < resnorm
        # Relative to a plain CG, `z_next` is negative
        energy_next = jnp.vdot((r_next + g) / 2, z_next)
        energy_diff = energy - energy_next
        if absdelta is not None:
            neg_energy_eps = -eps * jnp.abs(energy)
            accept_z_next |= (energy_diff >= neg_energy_eps
                             ) & (energy_diff < absdelta) & (nit >= miniter)
        jax.debug.print("in body {nit} \\\\ 4 ::", nit=nit)

        # include a junk switch to catch the case where none should be executed
        z_next_norm = jnp.linalg.norm(z_next, ord=tr_norm_ord)
        jax.debug.print("in body {nit} \\\\ 5 :: pre-index", nit=nit)
        # `argmax` returns the first True entry, so earlier criteria take
        # precedence; index 0 (the constant False) is selected only when no
        # criterion fired.
        index = jnp.argmax(
            jnp.array(
                [False, dBd <= 0, z_next_norm >= trust_radius, accept_z_next]
            )
        )
        jax.debug.print(
            "in body {nit} \\\\ 6 :: pre-switch {index}", nit=nit, index=index
        )
        iterp = lax.switch(index, [noop, step1, step2, step3], (iterp, z_next))
        jax.debug.print("in body {nit} \\\\ 7 :: post-switch", nit=nit)

        iterp = iterp._replace(
            z=z_next,
            r=r_next,
            d=d_next,
            energy=energy_next,
            nhev=iterp.nhev + 1,
            nit=nit
        )

        return iterp

    def cond_f(iterp: _CGSteihaugState) -> bool:
        jax.debug.print(
            "cond_f={c} maxiter={maxiter}", c=~iterp.done, maxiter=maxiter
        )
        return jnp.logical_not(iterp.done)

    # perform inner optimization to solve the constrained
    # quadratic subproblem using cg
    jax.debug.print("looped {result.done} {result}", result=init_param)
    result = lax.while_loop(cond_f, body_f, init_param)
    jax.debug.print("looped {result.done} {result}", result=result)

    pred_f = soa(result.step)
    result = _QuadSubproblemResult(
        step=result.step,
        hits_boundary=result.hits_boundary,
        pred_f=pred_f,
        nit=result.nit,
        nfev=0,
        njev=0,
        nhev=result.nhev + 1,
        success=True
    )

    return result


def rosenbrock(np):
    """Return the Rosenbrock test function (`np` is accepted but unused)."""
    def func(x):
        return jnp.sum(100. * jnp.diff(x)**2 + (1. - x[:-1])**2)

    return func


def himmelblau(np):
    """Return Himmelblau's test function (`np` is accepted but unused)."""
    def func(p):
        x, y = p
        return (x**2 + y - 11.)**2 + (x + y**2 - 7.)**2

    return func


def matyas(np):
    """Return the Matyas test function (`np` is accepted but unused)."""
    def func(p):
        x, y = p
        return 0.26 * (x**2 + y**2) - 0.48 * x * y

    return func


def eggholder(np):
    """Return the Eggholder test function (`np` is accepted but unused)."""
    def func(p):
        x, y = p
        return -(y + 47.) * jnp.sin(
            jnp.sqrt(jnp.abs(x / 2. + y + 47.))
        ) - x * jnp.sin(jnp.sqrt(jnp.abs(x - (y + 47.))))

    return func


def hessp(primals, tangents):
    """Hessian-vector product of the module-level `fun` at `primals`."""
    return jax.jvp(jax.grad(fun), (primals, ), (tangents, ))[1]


fun = eggholder(jnp)

if __name__ == "__main__":
    x0 = jnp.ones(2) * 100.
    f0, g0 = jax.value_and_grad(fun)(x0)

    kwargs = {
        "absdelta": 0.,
        "resnorm": 0.,
        "trust_radius": 1.,
        "norm_ord": 1,
    }

    _cg_steihaug_subproblem(f0, g0, partial(hessp, x0), **kwargs)