Skip to content

Instantly share code, notes, and snippets.

@MarisaKirisame
Created March 8, 2020 16:51
Show Gist options
  • Select an option

  • Save MarisaKirisame/09edade7b2ab2bef3bfb779904e3cfdb to your computer and use it in GitHub Desktop.

Select an option

Save MarisaKirisame/09edade7b2ab2bef3bfb779904e3cfdb to your computer and use it in GitHub Desktop.

Revisions

  1. MarisaKirisame renamed this gist Mar 8, 2020. 1 changed file with 0 additions and 0 deletions.
    File renamed without changes.
  2. MarisaKirisame created this gist Mar 8, 2020.
    143 changes: 143 additions & 0 deletions AD
    Original file line number Diff line number Diff line change
    @@ -0,0 +1,143 @@
    import math

    def sin(x):
    if isinstance(x, Dual):
    return Dual(sin(x.x), cos(x.x) * x.dx)
    return math.sin(x)

    def cos(x):
    if isinstance(x, Dual):
    return Dual(cos(x.x), -1 * sin(x.x) * x.dx)
    return math.cos(x)

    class Dual:
    def __init__(self, x, dx):
    self.x = x
    self.dx = dx

    def __add__(self, r):
    assert isinstance(r, Dual)
    return Dual(self.x + r.x, self.dx + r.dx)

    def __mul__(self, r):
    if isinstance(r, Dual):
    return Dual(self.x * r.x, self.x * r.dx + r.x * self.dx)
    assert isinstance(r, float)
    return Dual(self.x * r, r * self.dx)

    def __rmul__(self, r):
    return self * r

    def __repr__(self):
    return repr((self.x, self.dx))

    class Raw:
    def __mul__(self, r):
    assert isinstance(r, float)
    return Raw()

    def __rmul__(self, r):
    return self * r

    def __add__(self, r):
    assert isinstance(r, Raw)
    return Raw()

    def __repr__(self):
    return repr(())

    def f(x, y):
    return sin(x) * y, sin(y) + x

    print("raw:")
    print(f(Dual(1.0, Raw()), Dual(2.0, Raw())))

    print("forward mode:")
    print(f(Dual(1.0, 1.0), Dual(2.0, 0.0)))
    print(f(Dual(1.0, 0.0), Dual(2.0, 1.0)))

    class Ref:
    def __init__(self, v):
    self.v = v

    class WithBP:
    def __init__(self, rdx, bp):
    self.rdx = rdx
    self.bp = bp

    def __mul__(self, rhs):
    assert isinstance(rhs, float)
    r = Ref(0.0)
    bpv = self.bp.v
    def new_bp():
    self.rdx.v = self.rdx.v + r.v * rhs
    bpv()
    self.bp.v = new_bp
    return WithBP(r, self.bp)

    def __rmul__(self, rhs):
    return self * rhs

    def __add__(self, rhs):
    assert isinstance(rhs, WithBP)
    r = Ref(0.0)
    bpv = self.bp.v
    def new_bp():
    self.rdx.v = self.rdx.v + r.v
    rhs.rdx.v = rhs.rdx.v + r.v
    bpv()
    self.bp.v = new_bp
    return WithBP(r, self.bp)

    print("reverse mode:")
    bp = Ref(lambda: ())
    x = WithBP(Ref(0.0), bp)
    y = WithBP(Ref(0.0), bp)

    a, b = f(Dual(1.0, x), Dual(2.0, y))
    a.dx.rdx.v = 1.0
    a.dx.bp.v()
    print((x.rdx.v, y.rdx.v))

    bp.v = lambda: ()
    x = WithBP(Ref(0.0), bp)
    y = WithBP(Ref(0.0), bp)

    a, b = f(Dual(1.0, x), Dual(2.0, y))
    b.dx.rdx.v = 1.0
    bp.v()
    print((x.rdx.v, y.rdx.v))

    class Batched:
    def __init__(self, *l):
    self.l = l

    def __mul__(self, rhs):
    assert isinstance(rhs, float)
    return Batched(*[x * rhs for x in self.l])

    def __add__(self, rhs):
    assert isinstance(rhs, Batched)
    assert len(self.l) == len(rhs.l)
    return Batched(*[self.l[i] + rhs.l[i] for i in range(len(self.l))])

    def __rmul__(self, rhs):
    return self * rhs

    def __repr__(self):
    return repr(self.l)

    print("batched forward mode:")
    print(f(Dual(1.0, Batched(1.0, 0.0)), Dual(2.0, Batched(0.0, 1.0))))

    print("batched reverse mode:")
    bp = Ref(lambda: ())
    ax = WithBP(Ref(0.0), bp)
    bx = WithBP(Ref(0.0), bp)
    ay = WithBP(Ref(0.0), bp)
    by = WithBP(Ref(0.0), bp)
    a, b = f(Dual(1.0, Batched(ax, bx)), Dual(2.0, Batched(ay, by)))
    a.dx.l[0].rdx.v = 1.0
    b.dx.l[1].rdx.v = 1.0
    bp.v()
    print((ax.rdx.v, bx.rdx.v, ay.rdx.v, by.rdx.v))