Last active: April 1, 2023 18:51
Revisions
metric-space revised this gist on Apr 1, 2023. 1 changed file with 9 additions and 21 deletions.
@@ -1,5 +1,5 @@
import jax.numpy as jnp
from jax import jit, vmap, grad, value_and_grad
from jax import random
import jax

@@ -9,7 +9,7 @@
# hyperparameters
hidden_size = 100
seq_length = 25
learning_rate = 1e-3

def initialize_network_params(hidden_size, vocab_size, key):

@@ -61,6 +61,11 @@ def sample(params, hprev, seed_ix, n, key):
    return ixes, key_

@jit
def update(params, inputs, targets, hprev):
    (loss_, hprev), grads = value_and_grad(loss, has_aux=True)(params, inputs, targets, hprev)
    return [jnp.clip(w - learning_rate * dw, -5, 5) for (w, dw) in zip(params, grads)], loss_, hprev

with open('input.txt', 'r') as f:
    data = f.read()

@@ -78,12 +83,6 @@ def sample(params, hprev, seed_ix, n, key):
Wxh, Whh, Why, bh, by = params

while True:
    p = p + seq_length
    if ((p + seq_length + 1) >= len(data)) or n == 0:

@@ -99,19 +98,8 @@ def sample(params, hprev, seed_ix, n, key):
        txt = ''.join(ix_to_char[ix] for ix in sample_ix)
        print('----\n %s \n----' % (txt, ))

    params, loss_, hprev = update(params, inputs, targets, hprev)

    if n % 100 == 0:
        print('iter %d, loss: %f' % (n, loss_))

    n += 1
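The revised update step leans on jax.value_and_grad with has_aux=True to get the loss, the carried hidden state, and the gradients in a single call. For reference, here is a minimal self-contained sketch (my own toy example, not part of the gist) of what that transform returns, which makes the unpacking in update easier to follow:

import jax.numpy as jnp
from jax import value_and_grad

def loss_fn(w, x, target):
    # with has_aux=True the wrapped function must return (loss, aux)
    pred = w @ x
    loss = jnp.sum((pred - target) ** 2)
    return loss, pred

w = jnp.ones((2, 3))
x = jnp.arange(3.0)
target = jnp.zeros(2)

# value_and_grad(..., has_aux=True) returns ((loss, aux), grad_wrt_first_arg),
# the same shape the gist's update function unpacks
(loss_val, pred), dw = value_and_grad(loss_fn, has_aux=True)(w, x, target)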
metric-space created this gist on Mar 31, 2023.
@@ -0,0 +1,117 @@
import jax.numpy as jnp
from jax import jit, vmap, value_and_grad
from jax import random
import jax

SEED = 42234
key = random.PRNGKey(SEED)

# hyperparameters
hidden_size = 100
seq_length = 25
learning_rate = 1e-1

def initialize_network_params(hidden_size, vocab_size, key):
    key, *subkey = random.split(key, num=4)
    # model parameters
    # :: Matrix R[hidden_size] C[vocab_size]
    Wxh = random.normal(subkey[0], (hidden_size, vocab_size)) * 0.01  # input to hidden
    Whh = random.normal(subkey[1], (hidden_size, hidden_size)) * 0.01  # hidden to hidden
    Why = random.normal(subkey[2], (vocab_size, hidden_size)) * 0.01  # hidden to output
    bh = jnp.zeros((hidden_size, 1))  # hidden bias
    by = jnp.zeros((vocab_size, 1))  # output bias
    return Wxh, Whh, Why, bh, by

@jit
def loss(params, inputs, targets, hprev):
    loss = 0
    Wxh, Whh, Why, bh, by = params
    hprev = hprev.copy()
    for t in range(len(inputs)):
        x = jnp.zeros(by.shape)
        x = x.at[inputs[t]].set(1)
        y = targets[t]
        hprev = jnp.tanh(Wxh @ x + Whh @ hprev + bh)
        y_pred = Why @ hprev + by
        log_prob = jax.nn.log_softmax(y_pred.flatten())  # this (log_softmax) affects stability
        loss += -log_prob[y]
    return loss, hprev

def sample(params, hprev, seed_ix, n, key):
    Wxh, Whh, Why, bh, by = params
    x = jnp.zeros((vocab_size, 1))
    x = x.at[seed_ix].set(1)
    ixes = []
    key_ = key
    h = hprev.copy()
    for i in range(n):
        key_, subkey = random.split(key_)
        h = jnp.tanh(Wxh @ x + Whh @ h + bh)
        y_pred = Why @ h + by
        p = jax.nn.softmax(y_pred.flatten())
        ix = jax.random.choice(subkey, jnp.arange(vocab_size), p=p, replace=False)
        ix = int(ix)
        ixes.append(ix)
        x = jnp.zeros((vocab_size, 1))
        x = x.at[ix].set(1)
    return ixes, key_

with open('input.txt', 'r') as f:
    data = f.read()

chars = list(set(data))
char_to_ix = {ch: i for i, ch in enumerate(chars)}
ix_to_char = {i: ch for i, ch in enumerate(chars)}
vocab_size = len(chars)

params = initialize_network_params(hidden_size, vocab_size, key)

p = -seq_length
n = 0
hprev = jnp.zeros((hidden_size, 1))

Wxh, Whh, Why, bh, by = params
mWxh, mWhh, mWhy = jnp.zeros_like(Wxh), jnp.zeros_like(Whh), jnp.zeros_like(Why)
mbh, mby = jnp.zeros_like(bh), jnp.zeros_like(by)  # memory variables for Adagrad
smooth_loss = -jnp.log(1.0/vocab_size) * seq_length

params = list(params)

while True:
    p = p + seq_length
    if ((p + seq_length + 1) >= len(data)) or n == 0:
        p = 0
        hprev = jnp.zeros((hidden_size, 1))

    inputs = [char_to_ix[ch] for ch in data[p:p+seq_length]]
    targets = [char_to_ix[ch] for ch in data[p+1:p+seq_length+1]]

    if n % 1000 == 0:
        key, subkey = random.split(key)
        sample_ix, key = sample(params, hprev, inputs[0], 200, subkey)
        txt = ''.join(ix_to_char[ix] for ix in sample_ix)
        print('----\n %s \n----' % (txt, ))

    (loss_, hprev), grads = value_and_grad(loss, has_aux=True)(params, inputs, targets, hprev)
    dWxh, dWhh, dWhy, dbh, dby = grads
    for dparam in [dWxh, dWhh, dWhy, dbh, dby]:
        dparam = jnp.clip(dparam, -5, 5)

    smooth_loss = smooth_loss * 0.999 + loss_ * 0.001
    if n % 100 == 0:
        print('iter %d, loss: %f' % (n, smooth_loss))

    mem_vars = [mWxh, mWhh, mWhy, mbh, mby]
    for idx, (param, dparam, mem) in enumerate(zip(params, grads, mem_vars)):
        mem_updated = mem + dparam * dparam
        params[idx] = param - learning_rate * dparam / jnp.sqrt(mem_updated + 1e-8)  # Adagrad update
        mem_vars[idx] = mem_updated
    mWxh, mWhh, mWhy, mbh, mby = mem_vars

    n += 1
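One thing worth flagging in this first version of the training loop: the line dparam = jnp.clip(dparam, -5, 5) only rebinds the loop variable, so the Adagrad step below it still operates on the unclipped gradients. A minimal sketch of a clipped Adagrad step that actually applies the clipping; the helper adagrad_step and the toy values are my own illustration, not from the gist:

import jax.numpy as jnp

def adagrad_step(params, grads, mem, lr=1e-1, clip=5.0, eps=1e-8):
    # clip each gradient, accumulate its square, then apply the per-parameter scaled step
    new_params, new_mem = [], []
    for w, g, m in zip(params, grads, mem):
        g = jnp.clip(g, -clip, clip)
        m = m + g * g
        new_params.append(w - lr * g / jnp.sqrt(m + eps))
        new_mem.append(m)
    return new_params, new_mem

# toy usage with two dummy parameters
params = [jnp.ones((2, 2)), jnp.zeros((3, 1))]
grads = [jnp.full((2, 2), 10.0), jnp.full((3, 1), -0.5)]
mem = [jnp.zeros_like(p) for p in params]
params, mem = adagrad_step(params, grads, mem)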