Last active
August 13, 2019 22:41
-
-
Save pierrelux/059bfef496354cd8fa0ff4557db3b58b to your computer and use it in GitHub Desktop.
Revisions
-
pierrelux revised this gist
Aug 13, 2019 . 1 changed file with 1 addition and 1 deletion.There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters. Learn more about bidirectional Unicode charactersOriginal file line number Diff line number Diff line change @@ -10,7 +10,7 @@ def _f(arg): return _f def mixed_jvp(f, order, primals, tangents): frozen_func = freeze(grad(f, order[0]), argnum=order[1], val=primals[order[0]]) return jvp(frozen_func, (primals[order[1]],), tangents) mixed_jvp(f, order=(0,1), primals=(2., 3.), tangents=(1.,)) -
pierrelux revised this gist
Aug 13, 2019 . 1 changed file with 3 additions and 0 deletions.There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters. Learn more about bidirectional Unicode charactersOriginal file line number Diff line number Diff line change @@ -1,5 +1,8 @@ from jax import jvp, grad def f(x,y): return x + y**2 def freeze(f, argnum, val): def _f(arg): args = [val, arg] if argnum == 0 else [arg, val] -
pierrelux created this gist
Aug 13, 2019 .There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters. Learn more about bidirectional Unicode charactersOriginal file line number Diff line number Diff line change @@ -0,0 +1,13 @@ from jax import jvp, grad def freeze(f, argnum, val): def _f(arg): args = [val, arg] if argnum == 0 else [arg, val] return f(*args) return _f def mixed_jvp(f, order, primals, tangents): frozen_func = freeze(grad(f, order[0]), argntum=order[1], val=primals[order[0]]) return jvp(frozen_func, (primals[order[1]],), tangents) mixed_jvp(f, order=(0,1), primals=(2., 3.), tangents=(1.,))