"""Small demonstration of automatic differentiation built on dual numbers.

A ``Dual`` pairs a value with a "derivative component".  The same function
``f`` is evaluated with different kinds of derivative component:

* ``Raw``      - carries no information (plain evaluation),
* ``float``    - forward mode,
* ``WithBP``   - reverse mode, via a chain of backpropagation closures,
* ``Batched``  - a tuple of any of the above, handled element-wise.
"""

import math


def sin(x):
    """Return the sine of ``x``.

    For a ``Dual`` the result is a ``Dual`` whose derivative component is
    ``cos(x.x) * x.dx``; any other argument is passed to ``math.sin``.
    """
    if isinstance(x, Dual):
        return Dual(sin(x.x), cos(x.x) * x.dx)
    return math.sin(x)


def cos(x):
    """Return the cosine of ``x``.

    For a ``Dual`` the result is a ``Dual`` whose derivative component is
    ``-1 * sin(x.x) * x.dx``; any other argument is passed to ``math.cos``.
    """
    if isinstance(x, Dual):
        return Dual(cos(x.x), -1 * sin(x.x) * x.dx)
    return math.cos(x)


class Dual:
    """A value ``x`` together with a derivative component ``dx``.

    ``dx`` may be any object that supports addition with its own kind and
    multiplication by a ``float`` from either side.
    """

    def __init__(self, x, dx):
        self.x = x
        self.dx = dx

    def __add__(self, r):
        """Add two duals component-wise (sum rule)."""
        assert isinstance(r, Dual)
        return Dual(self.x + r.x, self.dx + r.dx)

    def __mul__(self, r):
        """Multiply by another ``Dual`` (product rule) or by a ``float``."""
        if isinstance(r, Dual):
            # The float factor is kept on the left so that the derivative
            # component's __rmul__ handles the scaling.
            return Dual(self.x * r.x, self.x * r.dx + r.x * self.dx)
        assert isinstance(r, float)
        return Dual(self.x * r, r * self.dx)

    def __rmul__(self, r):
        """Handle ``r * self`` by delegating to ``self * r``."""
        return self * r

    def __repr__(self):
        """Show the dual as the tuple ``(x, dx)``."""
        return repr((self.x, self.dx))


class Raw:
    """Derivative component that stores nothing.

    Every operation returns a fresh ``Raw``, so using it as ``dx`` evaluates
    a function without tracking any derivative.
    """

    def __mul__(self, r):
        assert isinstance(r, float)
        return Raw()

    def __rmul__(self, r):
        return self * r

    def __add__(self, r):
        assert isinstance(r, Raw)
        return Raw()

    def __repr__(self):
        """Show as an empty tuple."""
        return repr(())


def f(x, y):
    """Example function: return the pair ``(sin(x) * y, sin(y) + x)``."""
    return sin(x) * y, sin(y) + x


print("raw:")
print(f(Dual(1.0, Raw()), Dual(2.0, Raw())))

print("forward mode:")
# One pass per input: seed that input's derivative with 1.0, the other with 0.0.
print(f(Dual(1.0, 1.0), Dual(2.0, 0.0)))
print(f(Dual(1.0, 0.0), Dual(2.0, 1.0)))


class Ref:
    """Mutable cell holding a single value in ``v``."""

    def __init__(self, v):
        self.v = v


class WithBP:
    """Derivative component for reverse mode.

    ``rdx`` is a ``Ref`` in which the gradient for this node is accumulated.
    ``bp`` is a ``Ref`` holding a zero-argument callable.  Each operation
    replaces that callable with one that first pushes the new node's
    gradient into its operand(s) and then calls the previous callable, so
    invoking ``bp.v()`` processes the recorded operations newest-first.
    """

    def __init__(self, rdx, bp):
        self.rdx = rdx
        self.bp = bp

    def __mul__(self, rhs):
        """Record scaling by the float ``rhs`` and return the result node."""
        assert isinstance(rhs, float)
        r = Ref(0.0)
        # Capture the current callable now; self.bp.v is overwritten below.
        bpv = self.bp.v

        def new_bp():
            self.rdx.v = self.rdx.v + r.v * rhs
            bpv()

        self.bp.v = new_bp
        return WithBP(r, self.bp)

    def __rmul__(self, rhs):
        """Handle ``rhs * self`` by delegating to ``self * rhs``."""
        return self * rhs

    def __add__(self, rhs):
        """Record addition of another ``WithBP`` and return the result node."""
        assert isinstance(rhs, WithBP)
        r = Ref(0.0)
        # Capture the current callable now; self.bp.v is overwritten below.
        bpv = self.bp.v

        def new_bp():
            # The result's gradient flows unchanged into both operands.
            self.rdx.v = self.rdx.v + r.v
            rhs.rdx.v = rhs.rdx.v + r.v
            bpv()

        self.bp.v = new_bp
        return WithBP(r, self.bp)


print("reverse mode:")
# One pass per output: seed that output's gradient with 1.0, then run the chain.
bp = Ref(lambda: ())
x = WithBP(Ref(0.0), bp)
y = WithBP(Ref(0.0), bp)
a, b = f(Dual(1.0, x), Dual(2.0, y))
a.dx.rdx.v = 1.0
a.dx.bp.v()
print((x.rdx.v, y.rdx.v))

# Reset the chain before recording the second pass.
bp.v = lambda: ()
x = WithBP(Ref(0.0), bp)
y = WithBP(Ref(0.0), bp)
a, b = f(Dual(1.0, x), Dual(2.0, y))
b.dx.rdx.v = 1.0
bp.v()
print((x.rdx.v, y.rdx.v))


class Batched:
    """Tuple of derivative components processed element-wise.

    The components are stored in ``l``.  Scaling applies to every component;
    addition pairs up components of two equally long batches.
    """

    def __init__(self, *l):
        self.l = l

    def __mul__(self, rhs):
        """Scale every component by the float ``rhs``."""
        assert isinstance(rhs, float)
        return Batched(*[x * rhs for x in self.l])

    def __add__(self, rhs):
        """Add two batches of equal length component by component."""
        assert isinstance(rhs, Batched)
        assert len(self.l) == len(rhs.l)
        return Batched(*[mine + theirs for mine, theirs in zip(self.l, rhs.l)])

    def __rmul__(self, rhs):
        """Handle ``rhs * self`` by delegating to ``self * rhs``."""
        return self * rhs

    def __repr__(self):
        """Show as the tuple of components."""
        return repr(self.l)


print("batched forward mode:")
# Both input seeds are carried in a single evaluation of f.
print(f(Dual(1.0, Batched(1.0, 0.0)), Dual(2.0, Batched(0.0, 1.0))))

print("batched reverse mode:")
# Slot 0 of each batch is seeded from output a, slot 1 from output b;
# all four nodes share one backprop chain.
bp = Ref(lambda: ())
ax = WithBP(Ref(0.0), bp)
bx = WithBP(Ref(0.0), bp)
ay = WithBP(Ref(0.0), bp)
by = WithBP(Ref(0.0), bp)
a, b = f(Dual(1.0, Batched(ax, bx)), Dual(2.0, Batched(ay, by)))
a.dx.l[0].rdx.v = 1.0
b.dx.l[1].rdx.v = 1.0
bp.v()
print((ax.rdx.v, bx.rdx.v, ay.rdx.v, by.rdx.v))