@metric-space
Last active April 1, 2023 18:51

Revisions

  1. metric-space revised this gist Apr 1, 2023. 1 changed file with 9 additions and 21 deletions.
    30 changes: 9 additions & 21 deletions jaxkarpathy.py
@@ -1,5 +1,5 @@
 import jax.numpy as jnp
-from jax import jit, vmap, value_and_grad
+from jax import jit, vmap, grad, value_and_grad
 from jax import random
 import jax

@@ -9,7 +9,7 @@
 # hyperparameters
 hidden_size = 100
 seq_length = 25
-learning_rate = 1e-1
+learning_rate = 1e-3


 def initialize_network_params(hidden_size, vocab_size, key):
@@ -61,6 +61,11 @@ def sample(params, hprev, seed_ix, n, key):

     return ixes, key_

+@jit
+def update(params, inputs, targets, hprev):
+    (loss_, hprev),grads = value_and_grad(loss, has_aux=True)(params, inputs, targets, hprev)
+    return [jnp.clip(w - learning_rate * dw, -5, 5) for (w, dw) in zip(params, grads)], loss_, hprev
+

 with open('input.txt', 'r') as f:
     data = f.read()
@@ -78,12 +83,6 @@ def sample(params, hprev, seed_ix, n, key):

 Wxh, Whh, Why, bh, by = params

-mWxh, mWhh, mWhy = jnp.zeros_like(Wxh), jnp.zeros_like(Whh), jnp.zeros_like(Why)
-mbh, mby = jnp.zeros_like(bh), jnp.zeros_like(by) # memory variables for Adagrad
-smooth_loss = -jnp.log(1.0/vocab_size)*seq_length
-
-params = list(params)
-
 while True:
     p = p + seq_length
     if ((p + seq_length + 1) >= len(data)) or n == 0:
@@ -99,19 +98,8 @@ def sample(params, hprev, seed_ix, n, key):
         txt = ''.join(ix_to_char[ix] for ix in sample_ix)
         print('----\n %s \n----' % (txt, ))

-    (loss_, hprev), grads = value_and_grad(loss, has_aux=True)(params, inputs, targets, hprev)
-    dWxh, dWhh, dWhy, dbh, dby = grads
-    for dparam in [dWxh, dWhh, dWhy, dbh, dby]:
-        dparam = jnp.clip(dparam, -5, 5)
-    smooth_loss = smooth_loss * 0.999 + loss_ * 0.001
-
-    if n % 100 == 0: print('iter %d, loss: %f' % (n, smooth_loss))
+    params,loss_, hprev = update(params, inputs, targets, hprev)

-    mem_vars = [mWxh, mWhh, mWhy, mbh, mby]
-    for idx, (param, dparam, mem) in enumerate(zip(params, grads, mem_vars)):
-        mem_updated = mem + dparam * dparam
-        params[idx] = param - learning_rate * dparam / jnp.sqrt(mem_updated + 1e-8) # Adagrad update
-        mem_vars[idx] = mem_updated
+    if n % 100 == 0: print('iter %d, loss: %f' % (n, loss_))

-    mWxh, mWhh, mWhy, mbh, mby = mem_vars
     n += 1
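The new jitted update above relies on value_and_grad(loss, has_aux=True): because loss returns (loss, hprev), the transformed call yields ((loss, hprev), grads) in a single pass, so the hidden state is carried along with the gradients. A minimal standalone sketch of that calling convention (toy names only, not code from the gist):

    # Sketch: toy_step plays the role of the gist's loss(params, inputs, targets, hprev).
    import jax.numpy as jnp
    from jax import value_and_grad

    def toy_step(w, x, h):
        h_new = jnp.tanh(w * x + h)        # stand-in for one RNN step
        return jnp.sum(h_new ** 2), h_new  # (scalar loss, auxiliary carry)

    (loss_val, h_next), dw = value_and_grad(toy_step, has_aux=True)(2.0, jnp.ones(3), jnp.zeros(3))
    print(loss_val.shape, h_next.shape, dw.shape)  # (), (3,), ()

Note that update applies jnp.clip to the updated weights (bounding them to [-5, 5]) rather than to the gradients, and replaces the Adagrad bookkeeping with plain SGD at the smaller 1e-3 learning rate.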
  2. metric-space created this gist Mar 31, 2023.
    117 changes: 117 additions & 0 deletions jaxkarpathy.py
    @@ -0,0 +1,117 @@
import jax.numpy as jnp
from jax import jit, vmap, value_and_grad
from jax import random
import jax

SEED = 42234
key = random.PRNGKey(SEED)

# hyperparameters
hidden_size = 100
seq_length = 25
learning_rate = 1e-1


def initialize_network_params(hidden_size, vocab_size, key):
    key, *subkey = random.split(key, num=4)
    # model parameters
    # :: Matrix R[hidden_size] C[vocab_size]
    Wxh = random.normal(subkey[0], (hidden_size, vocab_size)) * 0.01 # input to hidden
    Whh = random.normal(subkey[1], (hidden_size, hidden_size)) * 0.01 # hidden to hidden
    Why = random.normal(subkey[2], (vocab_size, hidden_size)) * 0.01 # hidden to output
    bh = jnp.zeros((hidden_size, 1)) # hidden bias
    by = jnp.zeros((vocab_size, 1)) # output bias
    return Wxh, Whh, Why, bh, by

@jit
def loss(params, inputs, targets, hprev):
    loss = 0
    Wxh, Whh, Why, bh, by = params
    hprev = hprev.copy()
    for t in range(len(inputs)):
        x = jnp.zeros(by.shape)
        x = x.at[inputs[t]].set(1)
        y = targets[t]
        hprev = jnp.tanh(Wxh @ x + Whh @ hprev + bh)
        y_pred = Why @ hprev + by
        log_prob = jax.nn.log_softmax(y_pred.flatten()) # this (log_softmax) affects stability
        loss += -log_prob[y]

    return loss, hprev


def sample(params, hprev, seed_ix, n, key):
    Wxh, Whh, Why, bh, by = params
    x = jnp.zeros((vocab_size,1))
    x = x.at[seed_ix].set(1)
    ixes = []
    key_ = key
    h = hprev.copy()

    for i in range(n):
        key_, subkey = random.split(key_)
        h = jnp.tanh(Wxh @ x + Whh @ h + bh)
        y_pred = Why @ h + by
        p = jax.nn.softmax(y_pred.flatten())
        ix = jax.random.choice(subkey, jnp.arange(vocab_size), p=p, replace=False)
        ix = int(ix)
        ixes.append(ix)
        x = jnp.zeros((vocab_size,1))
        x = x.at[ix].set(1)

    return ixes, key_


with open('input.txt', 'r') as f:
    data = f.read()
chars = list(set(data))
char_to_ix = { ch:i for i,ch in enumerate(chars)}
ix_to_char = { i:ch for i, ch in enumerate(chars)}

vocab_size = len(chars)

params = initialize_network_params(hidden_size, vocab_size, key)

p = -seq_length
n = 0
hprev = jnp.zeros((hidden_size,1))

Wxh, Whh, Why, bh, by = params

mWxh, mWhh, mWhy = jnp.zeros_like(Wxh), jnp.zeros_like(Whh), jnp.zeros_like(Why)
mbh, mby = jnp.zeros_like(bh), jnp.zeros_like(by) # memory variables for Adagrad
smooth_loss = -jnp.log(1.0/vocab_size)*seq_length

params = list(params)

while True:
    p = p + seq_length
    if ((p + seq_length + 1) >= len(data)) or n == 0:
        p = 0
        hprev = jnp.zeros((hidden_size,1))

    inputs = [char_to_ix[ch] for ch in data[p:p+seq_length]]
    targets = [char_to_ix[ch] for ch in data[p+1:p+seq_length+1]]

    if n % 1000 == 0:
        key, subkey = random.split(key)
        sample_ix, key = sample(params, hprev, inputs[0], 200, subkey)
        txt = ''.join(ix_to_char[ix] for ix in sample_ix)
        print('----\n %s \n----' % (txt, ))

    (loss_, hprev), grads = value_and_grad(loss, has_aux=True)(params, inputs, targets, hprev)
    dWxh, dWhh, dWhy, dbh, dby = grads
    for dparam in [dWxh, dWhh, dWhy, dbh, dby]:
        dparam = jnp.clip(dparam, -5, 5)
    smooth_loss = smooth_loss * 0.999 + loss_ * 0.001

    if n % 100 == 0: print('iter %d, loss: %f' % (n, smooth_loss))

    mem_vars = [mWxh, mWhh, mWhy, mbh, mby]
    for idx, (param, dparam, mem) in enumerate(zip(params, grads, mem_vars)):
        mem_updated = mem + dparam * dparam
        params[idx] = param - learning_rate * dparam / jnp.sqrt(mem_updated + 1e-8) # Adagrad update
        mem_vars[idx] = mem_updated

    mWxh, mWhh, mWhy, mbh, mby = mem_vars
    n += 1
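For reference, the per-parameter Adagrad step that the training loop spells out inline can be written as a small pure function; this is only a sketch for clarity, not code from the gist:

    import jax.numpy as jnp

    def adagrad_step(param, grad, mem, lr=1e-1, eps=1e-8):
        # accumulate squared gradients, then scale the step by their running root
        mem = mem + grad * grad
        param = param - lr * grad / jnp.sqrt(mem + eps)
        return param, mem

Two practical notes on this first version: the line dparam = jnp.clip(dparam, -5, 5) only rebinds the loop variable, so the unclipped grads still reach the Adagrad update (the later revision moves clipping into update instead), and the script expects a plain-text corpus saved as input.txt and loops forever, so training has to be stopped manually.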