Llama implementation in pure Python (not mine; the original link is lost, so if you know it, drop a comment)
from dataclasses import dataclass
from typing import Optional, Tuple

import torch
from torch import nn
from torch.nn import functional as F

from tokenizer import Tokenizer


@dataclass
class ModelArgs:
    dim: int = 4096
    n_kv_heads: int = 8
    vocab_size: int = 128256  # 128000 BPE merges + 256 bytes tokens
    n_layers: int = 32
    n_heads: int = 32
    ffn_dim_multiplier: float = 1.3
    multiple_of: int = 256  # make SwiGLU hidden layer size multiple of large power of 2
    norm_eps: float = 1e-5
    rope_theta: float = 500000
    max_seq_len: int = 2048
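
# Note on the defaults above: they match the published config of Meta-Llama-3-8B
# (hidden size 4096, 32 layers, 32 attention heads, 8 KV heads for grouped-query
# attention, a 128256-token vocabulary, rope_theta = 500000, RMSNorm eps 1e-5).
# max_seq_len is capped at 2048 here for simplicity; the released model itself
# supports an 8192-token context.
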
class LlamaRotaryEmbedding(nn.Module):
    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):
        super().__init__()
        self.scaling_factor = scaling_factor
        self.dim = dim
        self.max_position_embeddings = max_position_embeddings
        self.base = base
        inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim))
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        # For BC we register cos and sin cached
        self.max_seq_len_cached = max_position_embeddings

    @torch.no_grad()
    def forward(self, x, position_ids):
        # x: [bs, num_attention_heads, seq_len, head_size]
        inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
        position_ids_expanded = position_ids[:, None, :].float()
        # Force float32 since bfloat16 loses precision on long contexts
        # See https://github.com/huggingface/transformers/pull/29285
        device_type = x.device.type
        device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
        with torch.autocast(device_type=device_type, enabled=False):
            freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
            emb = torch.cat((freqs, freqs), dim=-1)
            cos = emb.cos()
            sin = emb.sin()
        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
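
# Note on LlamaRotaryEmbedding above: inv_freq[i] = base ** (-2i / dim) gives one
# rotation frequency per pair of channels in a head. forward() takes the outer
# product of those frequencies with the integer positions, duplicates the result
# to cover the full head_dim, and returns cos/sin tables of shape
# (bs, seq_len, head_dim), computed in float32 and cast back to x's dtype.
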
def rotate_half(x):
    """Rotates half the hidden dims of the input."""
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)


def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
    cos = cos.unsqueeze(unsqueeze_dim)
    sin = sin.unsqueeze(unsqueeze_dim)
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed
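
# Note on the two helpers above: splitting a head into halves (x1, x2) and forming
# q * cos + rotate_half(q) * sin applies a 2D rotation to each (x1[i], x2[i]) pair,
# i.e. rotary position embedding. This is the "half-split" pairing used by the
# Hugging Face Llama code (whose checkpoints this gist loads), not the interleaved
# pairing of Meta's original implementation.
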
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
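
# Note on repeat_kv above: with the defaults (32 query heads, 8 KV heads) this is
# grouped-query attention with n_rep = 4. Each KV head is duplicated n_rep times
# via expand + reshape so that keys/values line up head-for-head with the queries
# before scaled_dot_product_attention, e.g. (bs, 8, seqlen, 128) -> (bs, 32, seqlen, 128).
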
class RMSNorm(nn.Module):
    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def _norm(self, x):
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    def forward(self, x):
        output = self._norm(x.float()).type_as(x)
        return output * self.weight
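
# Note on RMSNorm above: unlike LayerNorm there is no mean subtraction and no bias;
# each vector is scaled by 1 / sqrt(mean(x^2) + eps) and then by a learned per-channel
# weight. The normalization itself is done in float32 and cast back to the input dtype.
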
class Attention(nn.Module):
    def __init__(self, model_args: ModelArgs) -> None:
        super().__init__()
        self.dim, self.n_heads = model_args.dim, model_args.n_heads
        self.head_dim = model_args.dim // model_args.n_heads
        self.n_kv_heads = model_args.n_kv_heads
        self.n_rep = self.n_heads // self.n_kv_heads
        self.q_proj = nn.Linear(self.dim, self.n_heads * self.head_dim, bias=False)
        self.k_proj = nn.Linear(self.dim, self.n_kv_heads * self.head_dim, bias=False)
        self.v_proj = nn.Linear(self.dim, self.n_kv_heads * self.head_dim, bias=False)
        self.o_proj = nn.Linear(self.n_heads * self.head_dim, self.dim, bias=False)
        self.rotary_emb = LlamaRotaryEmbedding(self.head_dim, max_position_embeddings=model_args.max_seq_len, base=model_args.rope_theta)

    def forward(self, x, pos_ids):
        bs, seqlen, _ = x.shape
        xq, xk, xv = self.q_proj(x), self.k_proj(x), self.v_proj(x)
        xq = xq.view(bs, seqlen, self.n_heads, self.head_dim).transpose(1, 2)
        xk = xk.view(bs, seqlen, self.n_kv_heads, self.head_dim).transpose(1, 2)
        xv = xv.view(bs, seqlen, self.n_kv_heads, self.head_dim).transpose(1, 2)
        cos, sin = self.rotary_emb(xv, pos_ids)
        xq, xk = apply_rotary_pos_emb(xq, xk, cos, sin)
        # repeat k/v heads if n_kv_heads < n_heads
        xk = repeat_kv(xk, self.n_rep)  # (bs, n_heads, seqlen, head_dim)
        xv = repeat_kv(xv, self.n_rep)  # (bs, n_heads, seqlen, head_dim)
        # we use a causal mask for training
        output = F.scaled_dot_product_attention(xq, xk, xv, is_causal=True)
        output = output.transpose(1, 2).contiguous()  # (bs, seqlen, n_heads, head_dim)
        output = output.view(bs, seqlen, -1)
        return self.o_proj(output)
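
# Note on Attention above: F.scaled_dot_product_attention(..., is_causal=True) builds
# the lower-triangular mask internally, so no explicit mask tensor is needed. There is
# no KV cache in this implementation, so the generation loop at the bottom of the file
# re-runs attention over the full prefix at every step.
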
class MLP(nn.Module):
    def __init__(self, model_args: ModelArgs) -> None:
        super().__init__()
        hidden_dim = int(2 * model_args.dim * 4 / 3)
        hidden_dim = int(model_args.ffn_dim_multiplier * hidden_dim)
        hidden_dim = model_args.multiple_of * ((hidden_dim + model_args.multiple_of - 1) // model_args.multiple_of)
        self.gate_proj = nn.Linear(model_args.dim, hidden_dim, bias=False)
        self.up_proj = nn.Linear(model_args.dim, hidden_dim, bias=False)
        self.down_proj = nn.Linear(hidden_dim, model_args.dim, bias=False)

    def forward(self, x):
        return self.down_proj(F.silu(self.gate_proj(x)) * self.up_proj(x))
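
# Note on the SwiGLU hidden size above, worked through for the default ModelArgs:
#   int(2 * 4096 * 4 / 3)            -> 10922
#   int(1.3 * 10922)                 -> 14198
#   rounded up to a multiple of 256  -> 14336
# which matches the intermediate size of the Llama-3-8B checkpoint, so the projection
# shapes line up with the weights copied in from_pretrained below.
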
class TransformerBlock(nn.Module):
    def __init__(self, layer_id: int, model_args: ModelArgs) -> None:
        super().__init__()
        self.self_attn = Attention(model_args)
        self.mlp = MLP(model_args)
        self.input_layernorm = RMSNorm(dim=model_args.dim, eps=model_args.norm_eps)
        self.post_attention_layernorm = RMSNorm(dim=model_args.dim, eps=model_args.norm_eps)

    def forward(self, x, pos_ids):
        h = x + self.self_attn(self.input_layernorm(x), pos_ids)
        out = h + self.mlp(self.post_attention_layernorm(h))
        return out
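
# Note on TransformerBlock above: this is the standard pre-norm layout; RMSNorm is
# applied before the attention and MLP sub-layers, and each sub-layer's output is
# added back to its input as a residual.
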
class GPT(nn.Module):
    def __init__(self, model_args: ModelArgs) -> None:
        super().__init__()
        self.embed_tokens = nn.Embedding(model_args.vocab_size, model_args.dim)
        self.layers = nn.ModuleDict()
        for layer_id in range(model_args.n_layers):
            self.layers[str(layer_id)] = TransformerBlock(layer_id, model_args)
        self.norm = RMSNorm(dim=model_args.dim, eps=model_args.norm_eps)
        self.lm_head = nn.Linear(model_args.dim, model_args.vocab_size, bias=False)

    def forward(self, x):
        bs, seqlen = x.shape
        pos_ids = torch.arange(seqlen, device=x.device).unsqueeze(0).expand(bs, -1)
        h = self.embed_tokens(x)
        for layer in self.layers.values():
            # h = layer(h, self.freqs_cis)
            h = layer(h, pos_ids)
        h = self.norm(h)
        output = self.lm_head(h)
        return output

    @classmethod
    def from_pretrained(cls, model_type):
        config = ModelArgs()
        model = GPT(config)
        sd = model.state_dict()
        sd_keys = sd.keys()
        # init a huggingface/transformers model
        from transformers import AutoModelForCausalLM
        model_hf = AutoModelForCausalLM.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct")
        sd_hf = model_hf.state_dict()
        # copy while ensuring all of the parameters are aligned and match in names and shapes
        sd_keys_hf = sd_hf.keys()
        assert len(sd_keys_hf) == len(sd_keys), f"mismatched keys: {len(sd_keys_hf)} != {len(sd_keys)}"
        for k in sd_keys_hf:
            # vanilla copy over the other parameters
            assert sd_hf[k].shape == sd[k.replace('model.', '')].shape
            with torch.no_grad():
                sd[k.replace('model.', '')].copy_(sd_hf[k])
        return model
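
# Note on from_pretrained above: the module names in this file (embed_tokens, q_proj,
# input_layernorm, lm_head, ...) deliberately mirror the Hugging Face Llama naming, so
# copying weights only needs the "model." prefix stripped, e.g.
#   "model.layers.0.self_attn.q_proj.weight" -> "layers.0.self_attn.q_proj.weight"
# while "lm_head.weight" has no such prefix and is used as-is.
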
# --------------------------------------------------------------------------------------------------- #
num_return_sequences = 1
max_length = 100

# model = GPT(ModelArgs())
model = GPT.from_pretrained("llama3")

# print model layers
sd = model.state_dict()
for k, v in sd.items():
    print(k, v.shape)

model.eval()
model.cuda()

# prefix tokens
enc = Tokenizer(model_path="llama3/tokenizer.model")
tokens = enc.encode("Hello, I'm a language model,", bos=False, eos=False)
tokens = torch.tensor(tokens, dtype=torch.long)
tokens = tokens.unsqueeze(0).repeat(num_return_sequences, 1)
x = tokens.to('cuda')

torch.manual_seed(42)
torch.cuda.manual_seed(42)
while x.size(1) < max_length:
    with torch.no_grad():
        logits = model(x)
        logits = logits[:, -1, :]
        probs = F.softmax(logits, dim=-1)
        topk_probs, topk_indices = torch.topk(probs, 1, dim=-1)
        ix = torch.multinomial(topk_probs, num_samples=1)
        xcol = torch.gather(topk_indices, -1, ix)
        x = torch.cat((x, xcol), dim=1)
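
# Note on the loop above: torch.topk with k=1 followed by torch.multinomial over that
# single candidate is equivalent to greedy (argmax) decoding; increase k to get
# genuinely sampled continuations. Without a KV cache the full sequence is re-encoded
# on every iteration, so this is slow but simple.
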
for i in range(num_return_sequences):
    tokens = x[i, :max_length].tolist()
    try:
        # Try to find the index of token 128009 (<|eot_id|>, Llama 3's end-of-turn token)
        index = tokens.index(128009)
        # Cut off all tokens from this index onward
        tokens = tokens[:index]
    except ValueError:
        # Handle the case where 128009 is not in the list
        print("Token 128009 is not in the list. No changes made.")
    decoded = enc.decode(tokens)
    print(">", decoded)