import torch
import torch.nn as nn
import torch.nn.functional as F

# helpers

def make_unit_length(x, epsilon=1e-6):
    norm = x.norm(p=2, dim=-1, keepdim=True)
    return x.div(norm + epsilon)

def sort_key_val(t1, t2, dim=-1):
    values, indices = t1.sort(dim=dim)
    t2 = t2.expand_as(t1)
    return values, t2.gather(dim, indices)

def batched_index_select(values, indices):
    b = values.shape[0]
    return values[torch.arange(0, b), indices.transpose(0, 1)].transpose(0, 1)

# reversible net helper classes

class ReversibleBlock(nn.Module):
    def __init__(self, f_block, g_block, dim = 1):
        super().__init__()
        self.dim = dim
        self.f_block = f_block
        self.g_block = g_block

    def forward(self, x):
        x1, x2 = torch.chunk(x, 2, dim=self.dim)
        y1, y2 = None, None
        with torch.no_grad():
            y1 = x1 + self.f_block(x2)
            y2 = x2 + self.g_block(y1)
        return torch.cat([y1, y2], dim=self.dim)

    def backward_pass(self, y, dy):
        y1, y2 = torch.chunk(y, 2, dim=self.dim)
        del y

        dy1, dy2 = torch.chunk(dy, 2, dim=self.dim)
        del dy

        y1.requires_grad = True
        y2.requires_grad = True

        with torch.enable_grad():
            gy1 = self.g_block(y1)
            gy1.backward(dy2)

        with torch.no_grad():
            x2 = y2 - gy1
            del y2, gy1

            dx1 = dy1 + y1.grad
            del dy1
            y1.grad = None

        with torch.enable_grad():
            x2.requires_grad = True
            fx2 = self.f_block(x2)
            fx2.backward(dx1)

        with torch.no_grad():
            x1 = y1 - fx2
            del y1, fx2

            dx2 = dy2 + x2.grad
            del dy2
            x2.grad = None

            x = torch.cat([x1, x2.detach()], dim=self.dim)
            dx = torch.cat([dx1, dx2], dim=self.dim)

        return x, dx

class _ReversibleModuleFunction(torch.autograd.function.Function):
    @staticmethod
    def forward(ctx, x, reversible_blocks):
        for block in reversible_blocks:
            x = block(x)
        ctx.y = x.detach()
        ctx.reversible_blocks = reversible_blocks
        return x

    @staticmethod
    def backward(ctx, dy):
        y = ctx.y
        del ctx.y
        for i in range(len(ctx.reversible_blocks) - 1, -1, -1):
            y, dy = ctx.reversible_blocks[i].backward_pass(y, dy)
        del ctx.reversible_blocks
        return dy, None

class ReversibleSequence(nn.Module):
    def __init__(self, reversible_blocks):
        super().__init__()
        self.reversible_blocks = reversible_blocks

    def forward(self, x):
        x = _ReversibleModuleFunction.apply(x, self.reversible_blocks)
        return x
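# Illustrative sketch (not used by the model below): a quick check that
# ReversibleBlock.forward computes the additive coupling
#     y1 = x1 + f(x2),  y2 = x2 + g(y1)
# so that the inputs can later be recomputed from the outputs instead of being
# stored for the backward pass. The dimensions and the Linear sub-blocks here
# are arbitrary placeholders chosen for the example, not part of the model.
def _reversible_block_example():
    f = nn.Linear(8, 8)
    g = nn.Linear(8, 8)
    block = ReversibleBlock(f, g, dim=-1)

    x = torch.randn(2, 16)
    x1, x2 = torch.chunk(x, 2, dim=-1)

    with torch.no_grad():
        y = block(x)
        y1_ref = x1 + f(x2)
        y2_ref = x2 + g(y1_ref)
        assert torch.allclose(y, torch.cat([y1_ref, y2_ref], dim=-1), atol=1e-6)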
# lsh attention

class LSHAttention(nn.Module):
    def __init__(self,
                 dropout = 0.,
                 bucket_size = 64,
                 n_hashes = 8,
                 allow_duplicate_attention = False,
                 attend_across_buckets = False,
                 rehash_each_round = True,
                 drop_for_hash_rate = 0.0):
        super().__init__()
        if dropout >= 1.0:
            raise ValueError('Dropout rates must be lower than 1.')

        self.dropout = nn.Dropout(dropout)
        self.dropout_for_hash = nn.Dropout(drop_for_hash_rate)

        assert rehash_each_round or allow_duplicate_attention, (
            'The setting {allow_duplicate_attention=False, rehash_each_round=False}'
            ' is not implemented.')

        self.n_hashes = n_hashes
        self.bucket_size = bucket_size
        self._allow_duplicate_attention = allow_duplicate_attention
        self._attend_across_buckets = attend_across_buckets
        self._rehash_each_round = rehash_each_round

    def _sample_rotation(self, shape, vecs):
        device = vecs.device
        return torch.randn(shape, device=device)

    def hash_vectors(self, n_buckets, vecs):
        batch_size = vecs.shape[0]
        device = vecs.device

        # See https://arxiv.org/pdf/1509.02897.pdf
        # We sample a different random rotation for each round of hashing to
        # decrease the probability of hash misses.
        assert n_buckets % 2 == 0

        rot_size = n_buckets

        rotations_shape = (
            vecs.shape[-1],
            self.n_hashes if self._rehash_each_round else 1,
            rot_size // 2)

        random_rotations = self._sample_rotation(rotations_shape, vecs)

        dropped_vecs = self.dropout_for_hash(vecs)
        rotated_vecs = torch.einsum('btf,fhi->bhti', dropped_vecs, random_rotations)

        if self._rehash_each_round:
            rotated_vecs = torch.cat([rotated_vecs, -rotated_vecs], dim=-1)
            buckets = torch.argmax(rotated_vecs, dim=-1)
            # buckets is now (batch_size, self.n_hashes, seqlen). Next we add offsets so
            # that bucket numbers from different hashing rounds don't overlap.
            offsets = torch.arange(self.n_hashes, device=device)
            offsets = torch.reshape(offsets * n_buckets, (1, -1, 1))
            buckets = torch.reshape(buckets + offsets, (batch_size, -1,))
        else:
            rotated_vecs = torch.cat([rotated_vecs, -rotated_vecs], dim=-1)
            # In this configuration, we map each item to the top self.n_hashes buckets
            rotated_vecs = torch.squeeze(rotated_vecs, 0)
            bucket_range = torch.arange(0, rotated_vecs.shape[-1], device=device)
            bucket_range = torch.reshape(bucket_range, (1, -1))
            bucket_range = bucket_range.expand_as(rotated_vecs)

            _, buckets = sort_key_val(rotated_vecs, bucket_range, dim=-1)
            buckets = buckets[:, -self.n_hashes:]

            h, *_ = buckets.shape
            buckets = torch.reshape(buckets.permute((*_, h)), (-1,))

        return buckets

    def forward(self, qk, v):
        batch_size, seqlen, _ = qk.shape
        device = qk.device

        n_buckets = seqlen // self.bucket_size
        n_bins = n_buckets

        # We use the same vector as both a query and a key.
        buckets = self.hash_vectors(n_buckets, qk)
        assert int(buckets.shape[1]) == self.n_hashes * seqlen

        ticker = torch.arange(0, self.n_hashes * seqlen, device=device).unsqueeze(0)
        buckets_and_t = seqlen * buckets + (ticker % seqlen)
        buckets_and_t = buckets_and_t.detach()

        # Hash-based sort ("s" at the start of variable names means "sorted")
        sbuckets_and_t, sticker = sort_key_val(buckets_and_t, ticker, dim=-1)
        _, undo_sort = sort_key_val(sticker, ticker, dim=-1)

        sbuckets_and_t = sbuckets_and_t.detach()
        sticker = sticker.detach()
        undo_sort = undo_sort.detach()

        st = (sticker % seqlen)
        sqk = batched_index_select(qk, st)
        sv = batched_index_select(v, st)

        # Split off a "bin" axis so that attention only occurs within chunks.
        bq_t = bkv_t = torch.reshape(st, (batch_size, self.n_hashes * n_bins, -1))
        bqk = torch.reshape(sqk, (batch_size, self.n_hashes * n_bins, -1, sqk.shape[-1]))
        bv = torch.reshape(sv, (batch_size, self.n_hashes * n_bins, -1, sv.shape[-1]))
        bq_buckets = bkv_buckets = torch.reshape(sbuckets_and_t // seqlen, (batch_size, self.n_hashes * n_bins, -1))

        # Hashing operates on unit-length vectors. Unnormalized query vectors are
        # fine because they effectively provide a learnable temperature for the
        # attention softmax, but normalizing keys is needed so that similarity for
        # the purposes of attention correctly corresponds to hash locality.
        bq = bqk
        bk = make_unit_length(bqk)

        # Allow each chunk to attend within itself, and also one chunk back. Chunk
        # boundaries might occur in the middle of a sequence of items from the
        # same bucket, so this increases the chances of attending to relevant items.
        def look_one_back(x):
            x_extra = torch.cat([x[:, -1:, ...], x[:, :-1, ...]], dim=1)
            return torch.cat([x, x_extra], dim=2)

        bk = look_one_back(bk)
        bv = look_one_back(bv)
        bkv_t = look_one_back(bkv_t)
        bkv_buckets = look_one_back(bkv_buckets)

        # Dot-product attention.
        dots = torch.einsum('bhie,bhje->bhij', bq, bk) * (bq.shape[-1] ** -0.5)

        # Causal masking
        mask = bq_t[:, :, :, None] < bkv_t[:, :, None, :]
        dots = dots - 1e9 * mask

        # Mask out attention to self except when no other targets are available.
        self_mask = bq_t[:, :, :, None] == bkv_t[:, :, None, :]
        dots = dots - 1e5 * self_mask

        # Mask out attention to other hash buckets.
        if not self._attend_across_buckets:
            bucket_mask = bq_buckets[:, :, :, None] != bkv_buckets[:, :, None, :]
            dots = dots - 1e7 * bucket_mask

        # Don't double-count query-key pairs across multiple rounds of hashing.
        # There are two possible strategies here. (1) The default is to count how
        # many times a query-key pair is repeated, and to lower its log-prob
        # correspondingly at each repetition. (2) When hard_k is set, the code
        # instead masks all but the first occurrence of each query-key pair.
        if not self._allow_duplicate_attention:
            locs1 = undo_sort // bq_t.shape[-1]
            locs2 = (locs1 + 1) % (self.n_hashes * n_bins)
            if not self._attend_across_buckets:
                locs1 = buckets * (self.n_hashes * n_bins) + locs1
                locs2 = buckets * (self.n_hashes * n_bins) + locs2
            locs = torch.cat([
                torch.reshape(locs1, (batch_size, self.n_hashes, seqlen)),
                torch.reshape(locs2, (batch_size, self.n_hashes, seqlen)),
            ], 1).permute((0, 2, 1))

            slocs = batched_index_select(locs, st)
            b_locs = torch.reshape(slocs, (batch_size, self.n_hashes * n_bins, -1, 2 * self.n_hashes))

            b_locs1 = b_locs[:, :, :, None, :self.n_hashes]

            bq_locs = b_locs1.expand(b_locs.shape[:3] + (2, self.n_hashes))
            bq_locs = torch.reshape(bq_locs, b_locs.shape)
            bkv_locs = look_one_back(b_locs)

            dup_counts = (bq_locs[:, :, :, None, :] == bkv_locs[:, :, None, :, :]).float().sum(dim=-1)
            dup_counts = dup_counts.detach()
            assert dup_counts.shape == dots.shape
            dots = dots - torch.log(dup_counts + 1e-9)

        # Softmax.
        dots_logsumexp = torch.logsumexp(dots, dim=-1, keepdim=True)
        dots = torch.exp(dots - dots_logsumexp)
        dots = self.dropout(dots)

        bo = torch.einsum('buij,buje->buie', dots, bv)
        so = torch.reshape(bo, (batch_size, -1, bo.shape[-1]))
        slogits = torch.reshape(dots_logsumexp, (batch_size, -1,))

        o = batched_index_select(so, undo_sort)
        _, logits = sort_key_val(sticker, slogits, dim=-1)

        if self.n_hashes == 1:
            out = o
        else:
            o = torch.reshape(o, (batch_size, self.n_hashes, seqlen, o.shape[-1]))
            logits = torch.reshape(logits, (batch_size, self.n_hashes, seqlen, 1))
            probs = torch.exp(logits - torch.logsumexp(logits, dim=1, keepdim=True))
            out = torch.sum(o * probs, dim=1)

        assert out.shape == v.shape
        return out
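# Illustrative sketch (not called anywhere): how LSHAttention is driven directly.
# The same tensor serves as both queries and keys, and the sequence length must
# be a multiple of bucket_size so it can be cut into equal hash bins. The sizes
# below are arbitrary placeholders for the example.
def _lsh_attention_example():
    attn = LSHAttention(bucket_size = 16, n_hashes = 2)
    qk = torch.randn(1, 64, 32)   # (batch, seqlen, dim), seqlen divisible by bucket_size
    v = torch.randn(1, 64, 32)
    out = attn(qk, v)
    assert out.shape == v.shape   # the output keeps the value shape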
class LSHSelfAttention(nn.Module):
    def __init__(self, emb, heads = 8, bucket_size = 64, n_hashes = 8, **kwargs):
        super().__init__()
        self.heads = heads

        self.toqk = nn.Linear(emb, emb * heads)
        self.tov = nn.Linear(emb, emb * heads)
        self.unify_heads = nn.Linear(emb * heads, emb)

        self.bucket_size = bucket_size
        self.lsh_attn = LSHAttention(bucket_size=bucket_size, **kwargs)

    def forward(self, x):
        b, t, e, h = *x.shape, self.heads
        assert t % self.bucket_size == 0, f'Sequence length needs to be divisible by target bucket size - {self.bucket_size}'

        qk = self.toqk(x)
        v = self.tov(x)

        def merge_heads(v):
            return v.view(b, t, h, e).transpose(1, 2).reshape(b * h, t, e)

        def split_heads(v):
            return v.view(b, h, t, e).transpose(1, 2).contiguous()

        qk = merge_heads(qk)
        v = merge_heads(v)

        attn_out = self.lsh_attn(qk, v)

        out = split_heads(attn_out).view(b, t, h * e)
        return self.unify_heads(out)

# feedforward with chunking

class FeedForward(nn.Module):
    def __init__(self, emb, mult = 4):
        super().__init__()
        self.emb = emb
        self.proj_in = nn.Linear(emb, emb * mult)
        self.proj_out = nn.Linear(emb * mult, emb)

    def forward(self, x):
        x = self.proj_in(x)
        x = F.gelu(x)
        x = self.proj_out(x)
        return x

class WithLayerNorm(nn.Module):
    def __init__(self, emb, fn):
        super().__init__()
        self.emb = emb
        self.norm = nn.LayerNorm(emb)
        self.fn = fn

    def forward(self, x):
        x = self.norm(x)
        return self.fn(x)

class Chunk(nn.Module):
    def __init__(self, chunks, fn, dim = -1):
        super().__init__()
        self.dim = dim
        self.chunks = chunks
        self.fn = fn

    def forward(self, x):
        chunks = x.chunk(self.chunks, dim = self.dim)
        return torch.cat([self.fn(c) for c in chunks], dim = self.dim)
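# Illustrative sketch (not called anywhere): Chunk splits the input along a chosen
# dimension, applies the wrapped module to each piece, and concatenates the results.
# Because FeedForward acts position-wise, chunking along the sequence dimension
# (dim = -2, as the Reformer below does) produces the same output as a single pass,
# while only one chunk's wide inner (emb * mult) activation has to exist at a time
# during the no-grad reversible forward. The sizes here are arbitrary placeholders.
def _chunked_feedforward_example():
    ff = FeedForward(emb = 32)
    chunked_ff = Chunk(4, ff, dim = -2)

    x = torch.randn(2, 64, 32)  # (batch, seq, emb)
    with torch.no_grad():
        assert torch.allclose(chunked_ff(x), ff(x), atol=1e-5)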
# reformer auto-regressive lm

class Reformer(nn.Module):
    def __init__(self, emb, depth, max_seq_len, num_tokens = 10000, heads = 8, bucket_size = 64, n_hashes = 8, ff_chunks = 100):
        super().__init__()
        self.emb = emb
        self.depth = depth

        self.token_emb = nn.Embedding(num_tokens, emb)
        self.pos_emb = nn.Embedding(max_seq_len, emb)

        blocks = []
        for _ in range(depth):
            f = WithLayerNorm(emb, LSHSelfAttention(emb, heads, bucket_size, n_hashes))
            g = Chunk(ff_chunks, WithLayerNorm(emb, FeedForward(emb)), dim = -2)
            blocks.append(ReversibleBlock(f, g, dim=-1))

        self.layers = ReversibleSequence(nn.ModuleList(blocks))
        self.to_logits = nn.Linear(emb, num_tokens)

    def forward(self, x):
        x = self.token_emb(x) + self.pos_emb(torch.arange(0, x.shape[1], device=x.device))
        x = torch.cat([x, x], dim = -1)
        x = self.layers(x)
        x = torch.stack(x.chunk(2, dim=-1)).sum(dim=0)
        return self.to_logits(x)

# testing

num_tokens = 10000

r = Reformer(
    emb = 512,
    depth = 12,
    max_seq_len = 1024,
    num_tokens = num_tokens,
    heads = 8,
    bucket_size = 64,
    n_hashes = 8,
    ff_chunks = 200
)

x = torch.randint(0, num_tokens, (1, 1024)).long()
y = r(x)
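# Illustrative sketch of a single training step, reusing the forward pass above.
# The next-token objective, optimizer choice, and learning rate are assumptions
# made for the example and are not prescribed by the code above.
opt = torch.optim.Adam(r.parameters(), lr = 1e-4)

loss = F.cross_entropy(
    y[:, :-1].reshape(-1, num_tokens),  # predictions for positions 0..1022
    x[:, 1:].reshape(-1)                # targets shifted one step ahead
)
loss.backward()                         # inputs are recomputed block by block via backward_pass
opt.step()
opt.zero_grad()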