Skip to content

Instantly share code, notes, and snippets.

@enhuiz
Created February 24, 2023 12:13
Show Gist options
  • Select an option

  • Save enhuiz/335fd95c521ae11a3a8eb5c33a77e231 to your computer and use it in GitHub Desktop.

Select an option

Save enhuiz/335fd95c521ae11a3a8eb5c33a77e231 to your computer and use it in GitHub Desktop.

Revisions

  1. enhuiz created this gist Feb 24, 2023.
    30 changes: 30 additions & 0 deletions cool_adam.py
    Original file line number Diff line number Diff line change
    @@ -0,0 +1,30 @@
    import matplotlib.pyplot as plt
    import torch
    from celluloid import Camera
    from torch import nn

    linear = nn.Linear(2, 1, bias=False)

    camera = Camera(plt.figure())

    optimizer = torch.optim.Adam(linear.parameters(), 1)

    x = torch.randn(1, 2)
    optimizer.zero_grad()
    y = linear(x)
    y.sum().backward()
    optimizer.step()

    plt.scatter(*linear.weight.tolist()[0])
    camera.snap()

    for _ in range(10):
    optimizer.zero_grad()
    optimizer.step()
    plt.scatter(*linear.weight.tolist()[0], color='blue')
    plt.xlim(-5, 5)
    print(linear.weight)
    camera.snap()

    animation = camera.animate()
    animation.save("test.mp4")