#!/usr/bin/env python # Copyright 2020 DeepMind Technologies Limited. All Rights Reserved. # Licensed under the Apache License, Version 2.0 # Modifications copyright Yash Bonde (C) 2021 Nimblebox.ai, Inc. # This file is peak Google! <3 # How far can you push Python before it's just too hard? from typing import Any, Dict, Iterable, List, Tuple, Optional import random import itertools import numpy as np from time import time from tqdm import tqdm import jax import haiku as hk import jax.numpy as jnp import logging logger = logging.getLogger() logger.setLevel(logging.INFO) ################################################################################ # Transformer # =========== # Just a causal attention transformer model (GPT) # Data set is a string sequence sampled with overlaps ################################################################################ class CausalSelfAttention(hk.MultiHeadAttention): def __call__(self, q, k = None, v = None, mask = None) -> jnp.ndarray: if q.ndim != 3: raise ValueError('Expect queries of shape [B, T, D].') seq_len = q.shape[1] causal_mask = np.tril(np.ones((seq_len, seq_len))) mask = mask * causal_mask if mask is not None else causal_mask return super().__call__( q, k if k is not None else q, v if v is not None else q, mask ) class DenseBlock(hk.Module): def __init__(self, init_scale, widening_factor = 4, name = None): super().__init__(name = name) self._init_scale = init_scale self._widening_factor = widening_factor def __call__(self, x: jnp.ndarray) -> jnp.ndarray: hiddens = x.shape[-1] initializer = hk.initializers.VarianceScaling(self._init_scale) x = hk.Linear( self._widening_factor * hiddens, w_init=initializer )(x) x = jax.nn.gelu(x) return hk.Linear( hiddens, w_init=initializer )(x) class Transformer(hk.Module): def __init__(self, num_heads: int, num_layers: int, dropout_rate: float, name = None): super().__init__(name = name) self._num_heads = num_heads self._num_layers = num_layers self._dropout_rate = dropout_rate def __call__( self, h: jnp.ndarray, mask : Optional[jnp.ndarray] = None, is_training: bool = True ): init_scale = 2. / self._num_layers dropout_rate = self._dropout_rate if is_training else 0.0 if mask is not None: mask = mask[:, None, None, :] for i in range(self._num_layers): # attention block h_norm = hk.LayerNorm( axis=-1, create_scale=True, create_offset=True, name=f"h{i}_ln_1" )(h) h_attn = CausalSelfAttention( num_heads=self._num_heads, key_size = 32, w_init_scale = init_scale, value_size = None, model_size = h.shape[-1], name=f"h{i}_attn" )(h_norm, mask=mask) h_attn = hk.dropout(hk.next_rng_key(), dropout_rate, h_attn) h = h + h_attn # dense block h_norm = hk.LayerNorm( axis=-1, create_scale=True, create_offset=True, name=f"h{i}_ln_2" )(h) h_dense = DenseBlock(init_scale, name=f"h{i}_dense")(h_norm) h_dense = hk.dropout(hk.next_rng_key(), dropout_rate, h_dense) h = h + h_dense h = hk.LayerNorm( axis=-1, create_scale=True, create_offset=True, name=f"ln_f" )(h) return h # data def infinite_shuffle(iterable: Iterable, buffer_size: int): ds = itertools.cycle(iterable) buf = [next(ds) for _ in range(buffer_size)] random.shuffle(buf) while 1: item = next(ds) idx = random.randint(0, buffer_size - 1) # Inclusive result, buf[idx] = buf[idx], item yield result class Dataset: def __init__(self, path: str, batch_size: int, sequence_length: int): """Load a single-file ASCII dataset in memory.""" self.vocab_size = 128 self._batch_size = batch_size with open(path, 'r') as f: corpus = f.read() if not corpus.isascii(): raise ValueError('Loaded corpus is not ASCII.') if '\0' in corpus: # Reserve 0 codepoint for pad token. raise ValueError('Corpus must not contain null byte.') # Tokenize by taking ASCII codepoints. corpus = np.array([ord(c) for c in corpus]).astype(np.int32) assert np.min(corpus) > 0 assert np.max(corpus) < self.vocab_size # Double-checking ASCII codepoints. crop_len = sequence_length + 1 num_batches, ragged = divmod(corpus.size, batch_size * crop_len) if ragged: corpus = corpus[:-ragged] corpus = corpus.reshape([-1, crop_len]) if num_batches < 10: raise ValueError(f'Only {num_batches} batches; consider a shorter sequence or a smaller batch.') self._ds = infinite_shuffle(corpus, batch_size * 10) def __next__(self): """Yield next mini-batch.""" batch = [next(self._ds) for _ in range(self._batch_size)] batch = np.stack(batch) # Create the language modeling observation/target pairs. return dict(obs=batch[:, :-1], target=batch[:, 1:]) def __iter__(self): return self @staticmethod def decode(tokens: List[int]): return ''.join(chr(t) for t in tokens) def generate(forward_fn, config, num_steps: int, state: Dict[str, Any], text: str = 'This is a RSockClient.'): tokens = np.array([[ord(c) for c in text]]) for _ in range(num_steps): output = forward_fn(state["params"], state["rng"], {"obs": tokens[:config.m]}, is_training = False) out_tokens = output[:, -1].argmax(axis=-1) tokens = np.concatenate([tokens, [out_tokens]], axis=1) return Dataset.decode(tokens[0]) ################################################################################ # Training # ======== # haiku is pure functional the forward operations must be written down as # functions. This the structure of the code: # forward_fn -> jnp.ndarray # lm_loss_fn -> jnp.ndarray # On data: # Since haiku really really is functional, nothing can have a side effect. # Thus the data object that has to be stored has to be written in OOPs # style. # # On code: # There are two different inits ``__init__`` and ``init``, the # former is called when the updater is created and the latter is called # just before the loop. # # Then there is the ``update`` method, which is called on every iteration. # The simplicity of this approach! ################################################################################ import functools import optax import os import pickle def build_forward_fn(config: Dict): def _forward(data, is_training: bool = False) -> jnp.ndarray: """Forward pass.""" tokens = data['obs'] input_mask = jnp.greater(tokens, 0) seq_length = tokens.shape[1] # Embed the input tokens and positions. embed_init = hk.initializers.TruncatedNormal(stddev=0.02) token_embedding_map = hk.Embed(config.vocab_size, config.c, w_init=embed_init) token_embs = token_embedding_map(tokens) positional_embeddings = hk.get_parameter( 'pos_embs', [config.m, config.c], init = embed_init ) input_embeddings = token_embs + positional_embeddings[:seq_length] # Run the transformer over the inputs. transformer = Transformer( num_heads=config.num_heads, num_layers=config.num_layers, dropout_rate=config.dropout_rate ) output_embeddings = transformer( input_embeddings, input_mask, is_training ) # Reverse the embeddings (untied). return hk.Linear(config.vocab_size)(output_embeddings) return _forward def lm_loss_fn(forward_fn, vocab_size: int, params, rng, data: Dict[str, jnp.ndarray], is_training: bool = True) -> jnp.ndarray: """Compute the loss on data wrt params.""" logits = forward_fn(params, rng, data, is_training) targets = jax.nn.one_hot(data['target'], vocab_size) assert logits.shape == targets.shape mask = jnp.greater(data['obs'], 0) loss = -jnp.sum(targets * jax.nn.log_softmax(logits), axis=-1) loss = jnp.sum(loss * mask) / jnp.sum(mask) return loss class ParamUpdater: """A stateless abstraction around an init_fn/update_fn pair. This extracts some common boilerplate from the training loop. """ def __init__(self, net_init, loss_fn, optimizer: optax.GradientTransformation): self._net_init = net_init self._loss_fn = loss_fn self._opt = optimizer @functools.partial(jax.jit, static_argnums=0) def init(self, rng, data) -> Dict: """Initializes state of the updater.""" out_rng, init_rng = jax.random.split(rng) params = self._net_init(init_rng, data) opt_state = self._opt.init(params) out = dict( step=np.array(0), rng=out_rng, opt_state=opt_state, params=params, ) return out @functools.partial(jax.jit, static_argnums=0) def update(self, state: Dict[str, Any], data: Dict[str, jnp.ndarray]) -> Tuple[Dict, Dict]: """Updates the state using some data and returns metrics.""" rng, new_rng = jax.random.split(state['rng']) params = state['params'] loss, g = jax.value_and_grad(self._loss_fn)(params, rng, data) updates, opt_state = self._opt.update(g, state['opt_state']) params = optax.apply_updates(params, updates) new_state = { 'step': state['step'] + 1, 'rng': new_rng, 'opt_state': opt_state, 'params': params, } metrics = { 'step': state['step'], 'loss': loss, } return new_state, metrics class CheckpointingUpdater: """A didactic checkpointing wrapper around an Updater. A more mature checkpointing implementation might: - Use np.savez() to store the core data instead of pickle. - Not block JAX async dispatch. - Automatically garbage collect old checkpoints. Again since haiku is functional anything that has to be stored is written in OOPs. """ def __init__(self, inner: ParamUpdater, checkpoint_dir: str, checkpoint_every_n: int = 10000): self._inner = inner self._checkpoint_dir = checkpoint_dir self._checkpoint_every_n = checkpoint_every_n def _checkpoint_paths(self): return [p for p in os.listdir(self._checkpoint_dir) if 'checkpoint_' in p] def init(self, rng, data): """Initialize experiment state.""" if not os.path.exists(self._checkpoint_dir) or not self._checkpoint_paths(): os.makedirs(self._checkpoint_dir, exist_ok=True) return self._inner.init(rng, data) else: checkpoint = os.path.join(self._checkpoint_dir, max(self._checkpoint_paths())) logging.info('Loading checkpoint from %s', checkpoint) with open(checkpoint, 'rb') as f: state = pickle.load(f) return state def update(self, state, data): """Update experiment state.""" # NOTE: This blocks until `state` is computed. If you want to use JAX async # dispatch, maintain state['step'] as a NumPy scalar instead of a JAX array. # Context: https://jax.readthedocs.io/en/latest/async_dispatch.html step = np.array(state['step']) if step % self._checkpoint_every_n == 0: path = os.path.join(self._checkpoint_dir, 'checkpoint_{:07d}.pkl'.format(step)) checkpoint_state = jax.device_get(state) logging.info('Serializing experiment state to %s', path) with open(path, 'wb') as f: pickle.dump(checkpoint_state, f) state, out = self._inner.update(state, data) return state, out ################################################################################ # Main # ==== # haiku is pure functional the forward operations must be written down as # functions. This the structure of the code: # forward_fn -> jnp.ndarray # lm_loss_fn -> jnp.ndarray ################################################################################ class Config(dict): __getattr__ = dict.__getitem__ __setattr__ = dict.__setitem__ def main( filepath: str, m: int = 64, c: int = 8, num_heads: int = 8, num_layers: int = 6, batch_size: int = 8, dropout_rate: int = 0.1, grad_clip_value: float = 1.0, learning_rate: float = 0.001, checkpoint_dir: str = './checkpoints', max_steps: int = 10000, log_every: int = 1000 ): """Train an ASCII language model on filepath""" config = Config( vocab_size = 128, # fixed for ASCII filepath = filepath, m = m, c = c, num_heads = num_heads, num_layers = num_layers, batch_size = batch_size, dropout_rate = dropout_rate, grad_clip_value = grad_clip_value, learning_rate = learning_rate, checkpoint_dir = checkpoint_dir, ) train_dataset = Dataset( path = filepath, batch_size = batch_size, sequence_length = m ) # Set up the model, loss, and updater. forward_fn = hk.transform(build_forward_fn(config)) generate_fn = functools.partial(generate, forward_fn.apply, config) loss_fn = functools.partial(lm_loss_fn, forward_fn.apply, config.vocab_size) optimizer = optax.chain( optax.clip_by_global_norm(config.grad_clip_value), optax.adam(config.learning_rate, b1=0.9, b2=0.99) ) updater = ParamUpdater(forward_fn.init, loss_fn, optimizer) updater = CheckpointingUpdater(updater, config.checkpoint_dir) # Initialize parameters. logging.info('Initializing parameters...') rng = jax.random.PRNGKey(428) data = next(train_dataset) state = updater.init(rng, data) logging.info('Starting train loop...') prev_time = time() pbar = tqdm(range(max_steps)) for step in pbar: data = next(train_dataset) # print({k:v.shape for k,v in data.items()}) state, metrics = updater.update(state, data) # We use JAX runahead to mask data preprocessing and JAX dispatch overheads. # Using values from state/metrics too often will block the runahead and can # cause these overheads to become more prominent. if step % log_every == 0: steps_per_sec = log_every / (time() - prev_time) prev_time = time() metrics.update({'steps_per_sec': steps_per_sec}) # generate a sample sample = generate_fn(32, state) logging.info({k: float(v) for k, v in metrics.items()}) logging.info('Generated sample: %s', sample) if __name__ == '__main__': import fire fire.Fire(main)