# %%
from functools import partial
from typing import NamedTuple, Optional, Tuple, Union

import jax
from jax import lax
from jax import numpy as jnp

N_RESET = 20


class CGResults(NamedTuple):
    """Result container for a conjugate-gradient solve."""
    x: jnp.ndarray
    nit: Union[int, jnp.ndarray]
    nfev: Union[int, jnp.ndarray]  # number of matrix-evaluations
    info: Union[int, jnp.ndarray]
    success: Union[bool, jnp.ndarray]


# The following is code adapted from Nicholas Mancuso to work with pytrees
class _QuadSubproblemResult(NamedTuple):
    """Outcome of the trust-region quadratic subproblem solver."""
    step: jnp.ndarray  # proposed step
    hits_boundary: Union[bool, jnp.ndarray]  # step came from a boundary branch
    pred_f: Union[float, jnp.ndarray]  # quadratic model evaluated at `step`
    nit: Union[int, jnp.ndarray]  # CG iterations performed
    nfev: Union[int, jnp.ndarray]
    njev: Union[int, jnp.ndarray]
    nhev: Union[int, jnp.ndarray]  # Hessian-vector product counter
    success: Union[bool, jnp.ndarray]


class _CGSteihaugState(NamedTuple):
    """Loop carry of the CG-Steihaug iteration."""
    z: jnp.ndarray  # current CG iterate
    r: jnp.ndarray  # current residual (initialised to the gradient)
    d: jnp.ndarray  # current search direction
    step: jnp.ndarray  # step selected once a termination branch fires
    energy: Union[None, float, jnp.ndarray]
    hits_boundary: Union[bool, jnp.ndarray]
    done: Union[bool, jnp.ndarray]  # loop-exit flag
    nit: Union[int, jnp.ndarray]
    nhev: Union[int, jnp.ndarray]


def second_order_approx(
    p: jnp.ndarray,
    cur_val: Union[float, jnp.ndarray],
    g: jnp.ndarray,
    hessp_at_xk,
) -> Union[float, jnp.ndarray]:
    """Evaluate the quadratic model ``cur_val + g.p + 0.5 * p.(H p)``.

    Parameters
    ----------
    p : step at which to evaluate the model
    cur_val : objective value at the expansion point
    g : gradient at the expansion point
    hessp_at_xk : callable mapping a vector to its Hessian-vector product
    """
    return cur_val + jnp.vdot(g, p) + 0.5 * jnp.vdot(p, hessp_at_xk(p))


def get_boundaries_intersections(
    z: jnp.ndarray, d: jnp.ndarray, trust_radius: Union[float, jnp.ndarray]
):
    """Find where the line ``z + t * d`` meets the Euclidean trust sphere.

    Solves ``a t^2 + b t + c = 0`` with ``a = d.d``, ``b = 2 z.d`` and
    ``c = z.z - trust_radius**2``.

    Returns
    -------
    (ra, rb) : the two roots ordered such that ``ra <= rb``
    """
    a = jnp.vdot(d, d)
    b = 2 * jnp.vdot(z, d)
    c = jnp.vdot(z, z) - trust_radius**2
    sqrt_discriminant = jnp.sqrt(b * b - 4 * a * c)
    # Give the square root the sign of `b` so the sum never cancels; the
    # second root then follows from the product of the roots `c / a`.
    aux = b + jnp.copysign(sqrt_discriminant, b)
    ta = -aux / (2 * a)
    tb = -2 * c / aux

    ra = jnp.where(ta < tb, ta, tb)
    rb = jnp.where(ta < tb, tb, ta)
    return (ra, rb)


def _cg_steihaug_subproblem(
    cur_val: Union[float, jnp.ndarray],
    g: jnp.ndarray,
    hessp_at_xk,
    *,
    trust_radius: Union[float, jnp.ndarray],
    tr_norm_ord: Union[None, int, float, jnp.ndarray] = None,
    resnorm: Optional[float],
    absdelta: Optional[float] = None,
    norm_ord: Union[None, int, float, jnp.ndarray] = None,
    miniter: Union[None, int] = None,
    maxiter: Union[None, int] = None,
) -> _QuadSubproblemResult:
    """Solve the trust-region quadratic subproblem with CG-Steihaug.

    Parameters
    ----------
    cur_val : objective value at the current point
    g : gradient at the current point
    hessp_at_xk : callable returning the Hessian-vector product for a vector
    trust_radius : radius against which the norm of the iterate is compared
    tr_norm_ord : order of the norm used for that comparison; defaults to
        ``jnp.inf``
    resnorm : the iterate is accepted once the residual norm drops below
        this value; ``None`` disables the criterion
    absdelta : the iterate is accepted once the energy change is below this
        value (and not more negative than a rounding tolerance) after at
        least `miniter` iterations; ``None`` disables the criterion
    norm_ord : order of the residual norm; defaults to 2
    miniter : defaults to ``min(6, maxiter)`` with `maxiter` replaced by
        ``20 * g.size`` if it is not given
    maxiter : defaults to ``max(min(200, 20 * g.size), miniter)``

    Returns
    -------
    _QuadSubproblemResult with `nfev` and `njev` fixed to 0 and `success`
    fixed to ``True``.

    Notes
    -----
    Progress is traced via `jax.debug.print`.
    NOTE(review): the boundary intersections are computed for the Euclidean
    norm while the boundary test defaults to the infinity norm -- confirm
    that this mismatch is intended.
    """
    tr_norm_ord = jnp.inf if tr_norm_ord is None else tr_norm_ord  # taken from JAX
    norm_ord = 2 if norm_ord is None else norm_ord  # TODO: change to 1
    maxiter_fallback = 20 * g.size  # taken from SciPy's NewtonCG minimizer
    miniter = jnp.minimum(
        6, maxiter if maxiter is not None else maxiter_fallback
    ) if miniter is None else miniter
    maxiter = jnp.maximum(
        jnp.minimum(200, maxiter_fallback), miniter
    ) if maxiter is None else maxiter

    common_dtp = g.dtype
    eps = 6. * jnp.finfo(common_dtp).eps

    # second-order Taylor series approximation at the current values, gradient,
    # and hessian
    soa = partial(
        second_order_approx, cur_val=cur_val, g=g, hessp_at_xk=hessp_at_xk
    )

    # helpers for internal switches in the main CGSteihaug logic
    def noop(
        param: Tuple[_CGSteihaugState, Union[float, jnp.ndarray]]
    ) -> _CGSteihaugState:
        # No termination criterion fired: hand the state back untouched.
        iterp, z_next = param
        return iterp

    def step1(
        param: Tuple[_CGSteihaugState, Union[float, jnp.ndarray]]
    ) -> _CGSteihaugState:
        # Non-positive curvature along `d`: take whichever boundary point
        # yields the smaller model value.
        iterp, z_next = param
        z, d, nhev = iterp.z, iterp.d, iterp.nhev

        ta, tb = get_boundaries_intersections(z, d, trust_radius)
        pa = z + ta * d
        pb = z + tb * d
        p_boundary = jnp.where(soa(pa) < soa(pb), pa, pb)
        # the two model evaluations each cost one Hessian-vector product
        return iterp._replace(
            step=p_boundary, nhev=nhev + 2, hits_boundary=True, done=True
        )

    def step2(
        param: Tuple[_CGSteihaugState, Union[float, jnp.ndarray]]
    ) -> _CGSteihaugState:
        # The next iterate leaves the trust region: stop at the boundary
        # using the larger of the two roots.
        iterp, z_next = param
        z, d = iterp.z, iterp.d

        _, tb = get_boundaries_intersections(z, d, trust_radius)
        p_boundary = z + tb * d
        return iterp._replace(step=p_boundary, hits_boundary=True, done=True)

    def step3(
        param: Tuple[_CGSteihaugState, Union[float, jnp.ndarray]]
    ) -> _CGSteihaugState:
        # A stopping criterion is met inside the region: accept `z_next`.
        iterp, z_next = param
        return iterp._replace(step=z_next, hits_boundary=False, done=True)

    # initialize the step
    p_origin = jnp.zeros_like(g)

    # init the state for the first iteration
    z = p_origin
    r = g
    d = -r
    energy = 0.
    init_param = _CGSteihaugState(
        z=z,
        r=r,
        d=d,
        step=p_origin,
        energy=energy,
        hits_boundary=False,
        done=maxiter == 0,
        nit=0,
        nhev=0
    )

    # Search for the min of the approximation of the objective function.
    def body_f(iterp: _CGSteihaugState) -> _CGSteihaugState:
        z, r, d = iterp.z, iterp.r, iterp.d
        energy, nit = iterp.energy, iterp.nit

        nit += 1
        jax.debug.print("in body {nit} \\\\ 1 ::", nit=nit)

        Bd = hessp_at_xk(d)
        dBd = jnp.vdot(d, Bd)

        r_squared = jnp.vdot(r, r)
        alpha = r_squared / dBd
        z_next = z + alpha * d

        r_next = r + alpha * Bd
        r_next_squared = jnp.vdot(r_next, r_next)

        beta_next = r_next_squared / r_squared
        d_next = -r_next + beta_next * d
        jax.debug.print("in body {nit} \\\\ 2 ::", nit=nit)

        accept_z_next = nit >= maxiter
        jax.debug.print(
            "in body {nit} \\\\ 3 :: accept_z_next={accept_z_next}",
            nit=nit,
            accept_z_next=accept_z_next
        )
        if norm_ord == 2:
            # reuse the squared norm that is already at hand
            r_next_norm = jnp.sqrt(r_next_squared)
        else:
            r_next_norm = jnp.linalg.norm(r_next, ord=norm_ord)
        # `resnorm` is declared optional; skip the criterion instead of
        # failing on a comparison against `None`
        if resnorm is not None:
            accept_z_next |= r_next_norm < resnorm
        # Relative to a plain CG, `z_next` is negative
        energy_next = jnp.vdot((r_next + g) / 2, z_next)
        energy_diff = energy - energy_next
        if absdelta is not None:
            neg_energy_eps = -eps * jnp.abs(energy)
            accept_z_next |= (
                energy_diff >= neg_energy_eps
            ) & (energy_diff < absdelta) & (nit >= miniter)
        jax.debug.print("in body {nit} \\\\ 4 ::", nit=nit)

        # include a junk switch to catch the case where none should be executed
        z_next_norm = jnp.linalg.norm(z_next, ord=tr_norm_ord)
        jax.debug.print("in body {nit} \\\\ 5 :: pre-index", nit=nit)
        # `argmax` yields the first `True`, i.e. negative curvature takes
        # precedence over leaving the region, which takes precedence over
        # acceptance; index 0 (noop) is selected if nothing is `True`
        index = jnp.argmax(
            jnp.array(
                [False, dBd <= 0, z_next_norm >= trust_radius, accept_z_next]
            )
        )
        jax.debug.print(
            "in body {nit} \\\\ 6 :: pre-switch {index}", nit=nit, index=index
        )
        iterp = lax.switch(index, [noop, step1, step2, step3], (iterp, z_next))
        jax.debug.print("in body {nit} \\\\ 7 :: post-switch", nit=nit)

        iterp = iterp._replace(
            z=z_next,
            r=r_next,
            d=d_next,
            energy=energy_next,
            nhev=iterp.nhev + 1,
            nit=nit
        )

        return iterp

    def cond_f(iterp: _CGSteihaugState) -> bool:
        jax.debug.print(
            "cond_f={c} maxiter={maxiter}", c=~iterp.done, maxiter=maxiter
        )
        return jnp.logical_not(iterp.done)

    # perform inner optimization to solve the constrained
    # quadratic subproblem using cg
    jax.debug.print("looped {result.done} {result}", result=init_param)
    result = lax.while_loop(cond_f, body_f, init_param)
    jax.debug.print("looped {result.done} {result}", result=result)

    pred_f = soa(result.step)
    # the final model evaluation costs one more Hessian-vector product
    result = _QuadSubproblemResult(
        step=result.step,
        hits_boundary=result.hits_boundary,
        pred_f=pred_f,
        nit=result.nit,
        nfev=0,
        njev=0,
        nhev=result.nhev + 1,
        success=True
    )

    return result


def rosenbrock(np):
    """Return the test function ``sum(100 * diff(x)**2 + (1 - x[:-1])**2)``.

    The `np` argument is currently unused; `jnp` is used unconditionally.
    """
    def func(x):
        return jnp.sum(100. * jnp.diff(x)**2 + (1. - x[:-1])**2)

    return func


def himmelblau(np):
    """Return Himmelblau's function of a two-element input.

    The `np` argument is currently unused.
    """
    def func(p):
        x, y = p
        return (x**2 + y - 11.)**2 + (x + y**2 - 7.)**2

    return func


def matyas(np):
    """Return the Matyas function of a two-element input.

    The `np` argument is currently unused.
    """
    def func(p):
        x, y = p
        return 0.26 * (x**2 + y**2) - 0.48 * x * y

    return func


def eggholder(np):
    """Return the Eggholder function of a two-element input.

    The `np` argument is currently unused; `jnp` is used unconditionally.
    """
    def func(p):
        x, y = p
        return -(y + 47.) * jnp.sin(
            jnp.sqrt(jnp.abs(x / 2. + y + 47.))
        ) - x * jnp.sin(jnp.sqrt(jnp.abs(x - (y + 47.))))

    return func


def hessp(primals, tangents):
    """Hessian-vector product of the module-level `fun`.

    Differentiates `jax.grad(fun)` at `primals` along `tangents` via
    `jax.jvp` and returns the tangent output.
    """
    return jax.jvp(jax.grad(fun), (primals, ), (tangents, ))[1]


fun = eggholder(jnp)
x0 = jnp.ones(2) * 100.
f0, g0 = jax.value_and_grad(fun)(x0)
kwargs = {
    "absdelta": 0.,
    "resnorm": 0.,
    "trust_radius": 1.,
    "norm_ord": 1,
}
_cg_steihaug_subproblem(f0, g0, partial(hessp, x0), **kwargs)