Skip to content

Instantly share code, notes, and snippets.

@pierrelux
Last active August 13, 2019 22:41
Show Gist options
  • Select an option

  • Save pierrelux/059bfef496354cd8fa0ff4557db3b58b to your computer and use it in GitHub Desktop.

Select an option

Save pierrelux/059bfef496354cd8fa0ff4557db3b58b to your computer and use it in GitHub Desktop.

Revisions

  1. pierrelux revised this gist Aug 13, 2019. 1 changed file with 1 addition and 1 deletion.
    2 changes: 1 addition & 1 deletion mixed_jacobian.py
    Original file line number Diff line number Diff line change
    @@ -10,7 +10,7 @@ def _f(arg):
    return _f

    def mixed_jvp(f, order, primals, tangents):
    frozen_func = freeze(grad(f, order[0]), argntum=order[1], val=primals[order[0]])
    frozen_func = freeze(grad(f, order[0]), argnum=order[1], val=primals[order[0]])
    return jvp(frozen_func, (primals[order[1]],), tangents)

    mixed_jvp(f, order=(0,1), primals=(2., 3.), tangents=(1.,))
  2. pierrelux revised this gist Aug 13, 2019. 1 changed file with 3 additions and 0 deletions.
    3 changes: 3 additions & 0 deletions mixed_jacobian.py
    Original file line number Diff line number Diff line change
    @@ -1,5 +1,8 @@
    from jax import jvp, grad

    def f(x,y):
    return x + y**2

    def freeze(f, argnum, val):
    def _f(arg):
    args = [val, arg] if argnum == 0 else [arg, val]
  3. pierrelux created this gist Aug 13, 2019.
    13 changes: 13 additions & 0 deletions mixed_jacobian.py
    Original file line number Diff line number Diff line change
    @@ -0,0 +1,13 @@
    from jax import jvp, grad

    def freeze(f, argnum, val):
    def _f(arg):
    args = [val, arg] if argnum == 0 else [arg, val]
    return f(*args)
    return _f

    def mixed_jvp(f, order, primals, tangents):
    frozen_func = freeze(grad(f, order[0]), argntum=order[1], val=primals[order[0]])
    return jvp(frozen_func, (primals[order[1]],), tangents)

    mixed_jvp(f, order=(0,1), primals=(2., 3.), tangents=(1.,))