Llama 3 implementation in plain PyTorch (not mine, original link is lost, if you know it -- drop a comment)
from dataclasses import dataclass
from typing import Optional, Tuple

import torch
from torch import nn
from torch.nn import functional as F

from tokenizer import Tokenizer


@dataclass
class ModelArgs:
    dim: int = 4096
    n_kv_heads: int = 8
    vocab_size: int = 128256  # 128000 BPE tokens + 256 special tokens
    n_layers: int = 32
    n_heads: int = 32
    ffn_dim_multiplier: float = 1.3
    multiple_of: int = 256  # make the SwiGLU hidden layer size a multiple of a large power of 2
    norm_eps: float = 1e-5
    rope_theta: float = 500000
    max_seq_len: int = 2048
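
# Explanatory note (not part of the original file): with these defaults the
# derived per-head width is head_dim = dim // n_heads = 4096 // 32 = 128, and
# grouped-query attention shares each of the 8 KV heads across
# n_heads // n_kv_heads = 4 query heads.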

class LlamaRotaryEmbedding(nn.Module):
    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):
        super().__init__()
        self.scaling_factor = scaling_factor
        self.dim = dim
        self.max_position_embeddings = max_position_embeddings
        self.base = base
        inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim))
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        # kept for backward compatibility with the HF implementation
        self.max_seq_len_cached = max_position_embeddings

    @torch.no_grad()
    def forward(self, x, position_ids):
        # x: [bs, num_attention_heads, seq_len, head_size]
        inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
        position_ids_expanded = position_ids[:, None, :].float()
        # Force float32 since bfloat16 loses precision on long contexts
        # See https://github.com/huggingface/transformers/pull/29285
        device_type = x.device.type
        device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
        with torch.autocast(device_type=device_type, enabled=False):
            freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
            emb = torch.cat((freqs, freqs), dim=-1)
            cos = emb.cos()
            sin = emb.sin()
        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)

def rotate_half(x):
    """Rotates half the hidden dims of the input."""
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)


def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
    cos = cos.unsqueeze(unsqueeze_dim)
    sin = sin.unsqueeze(unsqueeze_dim)
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed
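
# Optional sanity check (added, not in the original; toy shapes chosen for
# illustration only): RoPE rotates each (x_i, x_{i + head_dim/2}) pair by a
# position-dependent angle, so applying it must leave per-position vector norms
# unchanged.
_rope = LlamaRotaryEmbedding(dim=8)
_q = torch.randn(1, 2, 5, 8)                  # (bs, n_heads, seq_len, head_dim)
_k = torch.randn(1, 2, 5, 8)
_pos = torch.arange(5).unsqueeze(0)           # (bs, seq_len)
_cos, _sin = _rope(_q, _pos)
_q_rot, _ = apply_rotary_pos_emb(_q, _k, _cos, _sin)
assert torch.allclose(_q.norm(dim=-1), _q_rot.norm(dim=-1), atol=1e-5)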

def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
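
# GQA sanity check (added note, not in the original): repeat_kv expands the
# n_kv_heads axis so that keys/values line up with the query heads; with the
# defaults above, 8 KV heads become 32 via n_rep = 4. Toy shapes below are
# illustrative only.
_kv = torch.randn(1, 8, 5, 16)                # (bs, n_kv_heads, seq_len, head_dim)
assert repeat_kv(_kv, 4).shape == (1, 32, 5, 16)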

class RMSNorm(nn.Module):
    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def _norm(self, x):
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    def forward(self, x):
        output = self._norm(x.float()).type_as(x)
        return output * self.weight
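
# Explanatory note (added, not in the original): RMSNorm scales each vector by
# 1 / sqrt(mean(x**2) + eps) and multiplies by a learned per-channel weight;
# unlike LayerNorm there is no mean subtraction and no bias.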

class Attention(nn.Module):
    def __init__(self, model_args: ModelArgs) -> None:
        super().__init__()
        self.dim, self.n_heads = model_args.dim, model_args.n_heads
        self.head_dim = model_args.dim // model_args.n_heads
        self.n_kv_heads = model_args.n_kv_heads
        self.n_rep = self.n_heads // self.n_kv_heads
        self.q_proj = nn.Linear(self.dim, self.n_heads * self.head_dim, bias=False)
        self.k_proj = nn.Linear(self.dim, self.n_kv_heads * self.head_dim, bias=False)
        self.v_proj = nn.Linear(self.dim, self.n_kv_heads * self.head_dim, bias=False)
        self.o_proj = nn.Linear(self.n_heads * self.head_dim, self.dim, bias=False)
        self.rotary_emb = LlamaRotaryEmbedding(self.head_dim, max_position_embeddings=model_args.max_seq_len, base=model_args.rope_theta)

    def forward(self, x, pos_ids):
        bs, seqlen, _ = x.shape
        xq, xk, xv = self.q_proj(x), self.k_proj(x), self.v_proj(x)
        xq = xq.view(bs, seqlen, self.n_heads, self.head_dim).transpose(1, 2)
        xk = xk.view(bs, seqlen, self.n_kv_heads, self.head_dim).transpose(1, 2)
        xv = xv.view(bs, seqlen, self.n_kv_heads, self.head_dim).transpose(1, 2)
        cos, sin = self.rotary_emb(xv, pos_ids)
        xq, xk = apply_rotary_pos_emb(xq, xk, cos, sin)
        # repeat k/v heads if n_kv_heads < n_heads
        xk = repeat_kv(xk, self.n_rep)  # (bs, n_heads, seqlen, head_dim)
        xv = repeat_kv(xv, self.n_rep)  # (bs, n_heads, seqlen, head_dim)
        # we use a causal mask
        output = F.scaled_dot_product_attention(xq, xk, xv, is_causal=True)
        output = output.transpose(1, 2).contiguous()  # (bs, seqlen, n_heads, head_dim)
        output = output.view(bs, seqlen, -1)
        return self.o_proj(output)
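
# Note (added, not in the original): is_causal=True builds the causal mask on the
# fly inside scaled_dot_product_attention, and there is no KV cache here, so the
# generation loop at the bottom of the file re-runs attention over the full
# prefix at every step. Fine for a short demo, slow for long generations.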

class MLP(nn.Module):
    def __init__(self, model_args: ModelArgs) -> None:
        super().__init__()
        hidden_dim = int(2 * model_args.dim * 4 / 3)
        hidden_dim = int(model_args.ffn_dim_multiplier * hidden_dim)
        hidden_dim = model_args.multiple_of * ((hidden_dim + model_args.multiple_of - 1) // model_args.multiple_of)
        self.gate_proj = nn.Linear(model_args.dim, hidden_dim, bias=False)
        self.up_proj = nn.Linear(model_args.dim, hidden_dim, bias=False)
        self.down_proj = nn.Linear(hidden_dim, model_args.dim, bias=False)

    def forward(self, x):
        # SwiGLU: silu(gate_proj(x)) * up_proj(x), projected back down to dim
        return self.down_proj(F.silu(self.gate_proj(x)) * self.up_proj(x))
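
# Worked example (added note): with the defaults above the hidden size comes out as
#   int(2 * 4096 * 4 / 3)            = 10922
#   int(1.3 * 10922)                 = 14198
#   rounded up to a multiple of 256  = 14336
# which matches the intermediate_size of the Llama 3 8B checkpoint loaded below.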

class TransformerBlock(nn.Module):
    def __init__(self, layer_id: int, model_args: ModelArgs) -> None:
        super().__init__()
        self.self_attn = Attention(model_args)
        self.mlp = MLP(model_args)
        self.input_layernorm = RMSNorm(dim=model_args.dim, eps=model_args.norm_eps)
        self.post_attention_layernorm = RMSNorm(dim=model_args.dim, eps=model_args.norm_eps)

    def forward(self, x, pos_ids):
        h = x + self.self_attn(self.input_layernorm(x), pos_ids)
        out = h + self.mlp(self.post_attention_layernorm(h))
        return out
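
# Explanatory note (added): this is the standard pre-norm residual layout;
# RMSNorm is applied before each sub-layer and the residual is added after it
# (x + Attn(Norm(x)), then h + MLP(Norm(h))).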

class GPT(nn.Module):
    def __init__(self, model_args: ModelArgs) -> None:
        super().__init__()
        self.embed_tokens = nn.Embedding(model_args.vocab_size, model_args.dim)
        self.layers = nn.ModuleDict()
        for layer_id in range(model_args.n_layers):
            self.layers[str(layer_id)] = TransformerBlock(layer_id, model_args)
        self.norm = RMSNorm(dim=model_args.dim, eps=model_args.norm_eps)
        self.lm_head = nn.Linear(model_args.dim, model_args.vocab_size, bias=False)

    def forward(self, x):
        bs, seqlen = x.shape
        pos_ids = torch.arange(seqlen, device=x.device).unsqueeze(0).expand(bs, -1)
        h = self.embed_tokens(x)
        for layer in self.layers.values():
            # h = layer(h, self.freqs_cis)
            h = layer(h, pos_ids)
        h = self.norm(h)
        output = self.lm_head(h)
        return output
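
    # Note (added): pos_ids is simply 0..seqlen-1 for every batch row, and the
    # output has shape (bs, seqlen, vocab_size); the logits are raw, no softmax.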

    @classmethod
    def from_pretrained(cls, model_type):
        config = ModelArgs()
        model = GPT(config)
        sd = model.state_dict()
        sd_keys = sd.keys()

        # init a huggingface/transformers model
        from transformers import AutoModelForCausalLM
        model_hf = AutoModelForCausalLM.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct")
        sd_hf = model_hf.state_dict()

        # copy while ensuring all of the parameters are aligned and match in names and shapes
        sd_keys_hf = sd_hf.keys()
        assert len(sd_keys_hf) == len(sd_keys), f"mismatched keys: {len(sd_keys_hf)} != {len(sd_keys)}"
        for k in sd_keys_hf:
            # plain copy; HF parameter names carry a "model." prefix except lm_head.weight
            assert sd_hf[k].shape == sd[k.replace('model.', '')].shape
            with torch.no_grad():
                sd[k.replace('model.', '')].copy_(sd_hf[k])
        return model
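
# Note (added, not in the original): from_pretrained builds the randomly
# initialised GPT above, loads the Hugging Face checkpoint, and copies weights
# tensor-by-tensor, stripping the "model." prefix from the HF keys. Both models
# are resident at once, so loading the 8B checkpoint this way needs roughly two
# full copies of the weights in host memory. The model_type argument is unused.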

# --------------------------------------------------------------------------------------------------- #
num_return_sequences = 1
max_length = 100

# model = GPT(ModelArgs())
model = GPT.from_pretrained("llama3")

# print model layers
sd = model.state_dict()
for k, v in sd.items():
    print(k, v.shape)

model.eval()
model.cuda()

# prefix tokens
enc = Tokenizer(model_path="llama3/tokenizer.model")
tokens = enc.encode("Hello, I'm a language model,", bos=False, eos=False)
tokens = torch.tensor(tokens, dtype=torch.long)
tokens = tokens.unsqueeze(0).repeat(num_return_sequences, 1)
x = tokens.to('cuda')

torch.manual_seed(42)
torch.cuda.manual_seed(42)
while x.size(1) < max_length:
    with torch.no_grad():
        logits = model(x)
        logits = logits[:, -1, :]
        probs = F.softmax(logits, dim=-1)
        topk_probs, topk_indices = torch.topk(probs, 1, dim=-1)
        ix = torch.multinomial(topk_probs, num_samples=1)
        xcol = torch.gather(topk_indices, -1, ix)
        x = torch.cat((x, xcol), dim=1)

for i in range(num_return_sequences):
    tokens = x[i, :max_length].tolist()
    try:
        # try to find the index of token 128009 (<|eot_id|> in the Llama 3 tokenizer)
        index = tokens.index(128009)
        # cut off all tokens from this index onward
        tokens = tokens[:index]
    except ValueError:
        # 128009 is not in the list, keep the full sequence
        print("Token 128009 is not in the list. No changes made.")
    decoded = enc.decode(tokens)
    print(">", decoded)