import torch
import torch.nn as nn
import torch.nn.functional as F

class ReversibleBlock(nn.Module):
    def __init__(self, f_block, g_block, dim=1):
        super().__init__()
        self.dim = dim
        self.f_block = f_block
        self.g_block = g_block

    def forward(self, x):
        # Reversible coupling: split x into (x1, x2) and compute
        # y1 = x1 + f(x2), y2 = x2 + g(y1). No activations are kept for the
        # backward pass; they are reconstructed in backward_pass below.
        x1, x2 = torch.chunk(x, 2, dim=self.dim)
        with torch.no_grad():
            y1 = x1 + self.f_block(x2)
            y2 = x2 + self.g_block(y1)
        return torch.cat([y1, y2], dim=self.dim)

    def backward_pass(self, y, dy):
        # Reconstruct (x, dx) from (y, dy) without stored activations:
        #   x2 = y2 - g(y1),  x1 = y1 - f(x2)
        #   dx1 = dy1 + g'(y1)^T dy2,  dx2 = dy2 + f'(x2)^T dx1
        y1, y2 = torch.chunk(y, 2, dim=self.dim)
        del y
        dy1, dy2 = torch.chunk(dy, 2, dim=self.dim)
        del dy

        y1.requires_grad = True
        y2.requires_grad = True

        with torch.enable_grad():
            gy1 = self.g_block(y1)
            gy1.backward(dy2)

        with torch.no_grad():
            x2 = y2 - gy1
            del y2, gy1
            dx1 = dy1 + y1.grad
            del dy1
            y1.grad = None

        with torch.enable_grad():
            x2.requires_grad = True
            fx2 = self.f_block(x2)
            fx2.backward(dx1)

        with torch.no_grad():
            x1 = y1 - fx2
            del y1, fx2
            dx2 = dy2 + x2.grad
            del dy2
            x2.grad = None
            x = torch.cat([x1, x2.detach()], dim=self.dim)
            dx = torch.cat([dx1, dx2], dim=self.dim)

        return x, dx

class _ReversibleModuleFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, reversible_blocks):
        # Run the blocks without building a graph; only the final output is
        # kept, and block inputs are reconstructed one block at a time in backward.
        for block in reversible_blocks:
            x = block(x)
        ctx.y = x.detach()
        ctx.reversible_blocks = reversible_blocks
        return x

    @staticmethod
    def backward(ctx, dy):
        y = ctx.y
        del ctx.y
        for i in range(len(ctx.reversible_blocks) - 1, -1, -1):
            y, dy = ctx.reversible_blocks[i].backward_pass(y, dy)
        del ctx.reversible_blocks
        return dy, None

class ReversibleSequence(nn.Module):
    def __init__(self, reversible_blocks):
        super().__init__()
        self.reversible_blocks = reversible_blocks

    def forward(self, x):
        return _ReversibleModuleFunction.apply(x, self.reversible_blocks)
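
# Illustrative sketch (toy sizes, an assumption for demonstration): with a small
# linear f/g pair, the gradient recovered by ReversibleSequence's custom backward
# should match plain autograd applied to the same coupling
# y1 = x1 + f(x2), y2 = x2 + g(y1).
def _sketch_reversible_equivalence():
    dim = 8
    f, g = nn.Linear(dim, dim), nn.Linear(dim, dim)
    rev = ReversibleSequence(nn.ModuleList([ReversibleBlock(f, g, dim=-1)]))

    x = torch.randn(2, 4, 2 * dim, requires_grad=True)
    rev(x).sum().backward()
    grad_reversible = x.grad.clone()

    # Reference: the same computation tracked directly by autograd.
    x.grad = None
    x1, x2 = torch.chunk(x, 2, dim=-1)
    y1 = x1 + f(x2)
    y2 = x2 + g(y1)
    torch.cat([y1, y2], dim=-1).sum().backward()
    return torch.allclose(grad_reversible, x.grad, atol=1e-5)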

def make_unit_length(x, epsilon=1e-6):
    norm = x.norm(p=2, dim=-1, keepdim=True)
    return x.div(norm + epsilon)

def sort_key_val(t1, t2, dim=-1):
    # Sort t1 along `dim` and reorder t2 with the same permutation.
    values, indices = t1.sort(dim=dim)
    t2 = t2.expand_as(t1)
    return values, t2.gather(dim, indices)

def batched_index_select(values, indices):
    # Per-batch gather, i.e. out[b, i] = values[b, indices[b, i]].
    b = values.shape[0]
    return values[torch.arange(0, b), indices.transpose(0, 1)].transpose(0, 1)
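
# Illustrative sketch (toy tensors, assumed for demonstration): sort_key_val
# sorts the keys and carries the values along with the same permutation, and
# batched_index_select gathers rows per batch element. Together they implement
# the hash-based sort and its undo used by LSHAttention below.
def _sketch_sort_helpers():
    keys = torch.tensor([[3., 1., 2.]])
    ticker = torch.arange(3).unsqueeze(0)
    sorted_keys, order = sort_key_val(keys, ticker)   # [[1., 2., 3.]], [[1, 2, 0]]
    values = torch.randn(1, 3, 4)
    gathered = batched_index_select(values, order)    # rows of `values` in sorted order
    return sorted_keys, order, gathered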

class LSHAttention(nn.Module):
    def __init__(self,
                 dropout=0.,
                 bucket_size=64,
                 n_hashes=8,
                 allow_duplicate_attention=False,
                 attend_across_buckets=False,
                 rehash_each_round=True,
                 drop_for_hash_rate=0.0):
        super().__init__()
        if dropout >= 1.0:
            raise ValueError('Dropout rates must be lower than 1.')

        self.dropout = nn.Dropout(dropout)
        self.dropout_for_hash = nn.Dropout(drop_for_hash_rate)

        assert rehash_each_round or allow_duplicate_attention, (
            'The setting {allow_duplicate_attention=False, rehash_each_round=False}'
            ' is not implemented.')

        self.n_hashes = n_hashes
        self.bucket_size = bucket_size
        self._allow_duplicate_attention = allow_duplicate_attention
        self._attend_across_buckets = attend_across_buckets
        self._rehash_each_round = rehash_each_round

    def _sample_rotation(self, shape, vecs):
        device = vecs.device
        return torch.randn(shape, device=device)

    def hash_vectors(self, n_buckets, vecs):
        batch_size = vecs.shape[0]
        device = vecs.device

        # See https://arxiv.org/pdf/1509.02897.pdf
        # We sample a different random rotation for each round of hashing to
        # decrease the probability of hash misses. (A shape-level sketch of the
        # resulting bucket ids appears after this class.)
        assert n_buckets % 2 == 0
        rot_size = n_buckets

        rotations_shape = (
            vecs.shape[-1],
            self.n_hashes if self._rehash_each_round else 1,
            rot_size // 2)

        random_rotations = self._sample_rotation(rotations_shape, vecs)

        dropped_vecs = self.dropout_for_hash(vecs)
        rotated_vecs = torch.einsum('btf,fhi->bhti', dropped_vecs, random_rotations)

        if self._rehash_each_round:
            rotated_vecs = torch.cat([rotated_vecs, -rotated_vecs], dim=-1)
            buckets = torch.argmax(rotated_vecs, dim=-1)
            # buckets is now (batch_size, self.n_hashes, seqlen). Next we add offsets
            # so that bucket numbers from different hashing rounds don't overlap.
            offsets = torch.arange(self.n_hashes, device=device)
            offsets = torch.reshape(offsets * n_buckets, (1, -1, 1))
            buckets = torch.reshape(buckets + offsets, (batch_size, -1,))
        else:
            rotated_vecs = torch.cat([rotated_vecs, -rotated_vecs], dim=-1)
            # In this configuration, we map each item to the top self.n_hashes buckets.
            rotated_vecs = torch.squeeze(rotated_vecs, 0)
            bucket_range = torch.arange(0, rotated_vecs.shape[-1], device=device)
            bucket_range = torch.reshape(bucket_range, (1, -1))
            bucket_range = bucket_range.expand_as(rotated_vecs)

            _, buckets = sort_key_val(rotated_vecs, bucket_range, dim=-1)
            buckets = buckets[:, -self.n_hashes:]

            h, *_ = buckets.shape
            buckets = torch.reshape(buckets.permute((*_, h)), (-1,))

        return buckets

    def forward(self, qk, v):
        batch_size, seqlen, _ = qk.shape
        device = qk.device

        n_buckets = seqlen // self.bucket_size
        n_bins = n_buckets

        buckets = self.hash_vectors(n_buckets, qk)
        # We use the same vector as both a query and a key.
        assert int(buckets.shape[1]) == self.n_hashes * seqlen

        ticker = torch.arange(0, self.n_hashes * seqlen, device=device).unsqueeze(0)
        buckets_and_t = seqlen * buckets + (ticker % seqlen)
        buckets_and_t = buckets_and_t.detach()

        # Hash-based sort ("s" at the start of variable names means "sorted")
        sbuckets_and_t, sticker = sort_key_val(buckets_and_t, ticker, dim=-1)
        _, undo_sort = sort_key_val(sticker, ticker, dim=-1)
        sbuckets_and_t = sbuckets_and_t.detach()
        sticker = sticker.detach()
        undo_sort = undo_sort.detach()

        st = (sticker % seqlen)
        sqk = batched_index_select(qk, st)
        sv = batched_index_select(v, st)

        # Split off a "bin" axis so that attention only occurs within chunks.
        bq_t = bkv_t = torch.reshape(st, (batch_size, self.n_hashes * n_bins, -1))
        bqk = torch.reshape(sqk, (batch_size, self.n_hashes * n_bins, -1, sqk.shape[-1]))
        bv = torch.reshape(sv, (batch_size, self.n_hashes * n_bins, -1, sv.shape[-1]))
        bq_buckets = bkv_buckets = torch.reshape(sbuckets_and_t // seqlen, (batch_size, self.n_hashes * n_bins, -1))

        # Hashing operates on unit-length vectors. Unnormalized query vectors are
        # fine because they effectively provide a learnable temperature for the
        # attention softmax, but normalizing keys is needed so that similarity for
        # the purposes of attention correctly corresponds to hash locality.
        bq = bqk
        bk = make_unit_length(bqk)

        # Allow each chunk to attend within itself, and also one chunk back. Chunk
        # boundaries might occur in the middle of a sequence of items from the
        # same bucket, so this increases the chances of attending to relevant items.
        def look_one_back(x):
            x_extra = torch.cat([x[:, -1:, ...], x[:, :-1, ...]], dim=1)
            return torch.cat([x, x_extra], dim=2)

        bk = look_one_back(bk)
        bv = look_one_back(bv)
        bkv_t = look_one_back(bkv_t)
        bkv_buckets = look_one_back(bkv_buckets)

        # Dot-product attention, scaled by 1/sqrt(d).
        dots = torch.einsum('bhie,bhje->bhij', bq, bk) * (bq.shape[-1] ** -0.5)

        # Causal masking
        mask = bq_t[:, :, :, None] < bkv_t[:, :, None, :]
        dots = dots - 1e9 * mask

        # Mask out attention to self except when no other targets are available.
        self_mask = bq_t[:, :, :, None] == bkv_t[:, :, None, :]
        dots = dots - 1e5 * self_mask

        # Mask out attention to other hash buckets.
        if not self._attend_across_buckets:
            bucket_mask = bq_buckets[:, :, :, None] != bkv_buckets[:, :, None, :]
            dots = dots - 1e7 * bucket_mask

        # Don't double-count query-key pairs across multiple rounds of hashing.
        # There are two possible strategies here. (1) The default is to count how
        # many times a query-key pair is repeated, and to lower its log-prob
        # correspondingly at each repetition. (2) An alternative (not implemented
        # here) is to mask all but the first occurrence of each query-key pair.
        if not self._allow_duplicate_attention:
            locs1 = undo_sort // bq_t.shape[-1]
            locs2 = (locs1 + 1) % (self.n_hashes * n_bins)
            if not self._attend_across_buckets:
                locs1 = buckets * (self.n_hashes * n_bins) + locs1
                locs2 = buckets * (self.n_hashes * n_bins) + locs2
            locs = torch.cat([
                torch.reshape(locs1, (batch_size, self.n_hashes, seqlen)),
                torch.reshape(locs2, (batch_size, self.n_hashes, seqlen)),
            ], 1).permute((0, 2, 1))

            slocs = batched_index_select(locs, st)
            b_locs = torch.reshape(slocs, (batch_size, self.n_hashes * n_bins, -1, 2 * self.n_hashes))

            b_locs1 = b_locs[:, :, :, None, :self.n_hashes]

            bq_locs = b_locs1.expand(b_locs.shape[:3] + (2, self.n_hashes))
            bq_locs = torch.reshape(bq_locs, b_locs.shape)
            bkv_locs = look_one_back(b_locs)

            dup_counts = (bq_locs[:, :, :, None, :] == bkv_locs[:, :, None, :, :]).float().sum(dim=-1)
            dup_counts = dup_counts.detach()
            assert dup_counts.shape == dots.shape
            dots = dots - torch.log(dup_counts + 1e-9)

        # Softmax.
        dots_logsumexp = torch.logsumexp(dots, dim=-1, keepdim=True)
        dots = torch.exp(dots - dots_logsumexp)
        dots = self.dropout(dots)

        bo = torch.einsum('buij,buje->buie', dots, bv)
        so = torch.reshape(bo, (batch_size, -1, bo.shape[-1]))
        slogits = torch.reshape(dots_logsumexp, (batch_size, -1,))

        # Undo the hash-based sort, then combine the rounds of hashing with
        # weights given by each round's softmax normalizer.
        o = batched_index_select(so, undo_sort)
        _, logits = sort_key_val(sticker, slogits, dim=-1)

        if self.n_hashes == 1:
            out = o
        else:
            o = torch.reshape(o, (batch_size, self.n_hashes, seqlen, o.shape[-1]))
            logits = torch.reshape(logits, (batch_size, self.n_hashes, seqlen, 1))
            probs = torch.exp(logits - torch.logsumexp(logits, dim=1, keepdim=True))
            out = torch.sum(o * probs, dim=1)

        assert out.shape == v.shape
        return out
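
# Illustrative sketch (sizes are assumptions chosen for demonstration, not part
# of the model): LSHAttention takes qk and v of shape (batch, seqlen, dim) with
# seqlen divisible by bucket_size, and returns a tensor shaped like v. The
# hash_vectors call shows the (batch, n_hashes * seqlen) bucket ids referenced
# in the comments above.
def _sketch_lsh_attention():
    attn = LSHAttention(bucket_size=16, n_hashes=2)
    qk = torch.randn(2, 128, 32)
    v = torch.randn(2, 128, 32)
    buckets = attn.hash_vectors(128 // 16, qk)   # shape (2, 2 * 128)
    out = attn(qk, v)                            # shape (2, 128, 32), same as v
    return buckets.shape, out.shape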

class LSHSelfAttention(nn.Module):
    def __init__(self, emb, heads=8, bucket_size=64, n_hashes=8, **kwargs):
        super().__init__()
        self.heads = heads

        self.toqk = nn.Linear(emb, emb * heads)
        self.tov = nn.Linear(emb, emb * heads)
        self.unify_heads = nn.Linear(emb * heads, emb)

        self.bucket_size = bucket_size
        self.lsh_attn = LSHAttention(bucket_size=bucket_size, n_hashes=n_hashes, **kwargs)

    def forward(self, x):
        b, t, e, h = *x.shape, self.heads
        assert t % self.bucket_size == 0, f'Sequence length needs to be divisible by the bucket size - {self.bucket_size}'

        qk = self.toqk(x)
        v = self.tov(x)

        def merge_heads(v):
            return v.view(b, t, h, e).transpose(1, 2).reshape(b * h, t, e)

        def split_heads(v):
            return v.view(b, h, t, e).transpose(1, 2).contiguous()

        qk = merge_heads(qk)
        v = merge_heads(v)

        attn_out = self.lsh_attn(qk, v)

        out = split_heads(attn_out).view(b, t, h * e)
        return self.unify_heads(out)
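
# Illustrative sketch (assumed sizes): LSHSelfAttention maps (batch, seqlen, emb)
# back to (batch, seqlen, emb), provided seqlen is a multiple of bucket_size.
def _sketch_lsh_self_attention():
    attn = LSHSelfAttention(emb=64, heads=4, bucket_size=16, n_hashes=2)
    x = torch.randn(2, 128, 64)
    return attn(x).shape   # torch.Size([2, 128, 64])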

class FeedForward(nn.Module):
    def __init__(self, emb, mult=4):
        super().__init__()
        self.emb = emb
        self.proj_in = nn.Linear(emb, emb * mult)
        self.proj_out = nn.Linear(emb * mult, emb)

    def forward(self, x):
        x = self.proj_in(x)
        x = F.gelu(x)
        x = self.proj_out(x)
        return x

class WithLayerNorm(nn.Module):
    def __init__(self, emb, fn):
        super().__init__()
        self.emb = emb
        self.norm = nn.LayerNorm(emb)
        self.fn = fn

    def forward(self, x):
        x = self.norm(x)
        return self.fn(x)

class Chunk(nn.Module):
    def __init__(self, chunks, fn, dim=-1):
        super().__init__()
        self.dim = dim
        self.chunks = chunks
        self.fn = fn

    def forward(self, x):
        chunks = x.chunk(self.chunks, dim=self.dim)
        return torch.cat([self.fn(c) for c in chunks], dim=self.dim)
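
# Illustrative sketch (toy sizes): Chunk splits the input along `dim`, applies
# `fn` to each slice, and concatenates the results, so a position-wise module
# such as FeedForward can be run over a long sequence in smaller slices to
# lower peak activation memory.
def _sketch_chunked_feedforward():
    ff = Chunk(4, FeedForward(emb=32), dim=-2)
    x = torch.randn(2, 128, 32)   # split into 4 slices of 32 positions each
    return ff(x).shape            # torch.Size([2, 128, 32])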

class Reformer(nn.Module):
    def __init__(self, emb, depth, max_seq_len, num_tokens=10000, heads=8, bucket_size=64, n_hashes=8, ff_chunks=100):
        super().__init__()
        self.emb = emb
        self.depth = depth

        self.token_emb = nn.Embedding(num_tokens, emb)
        self.pos_emb = nn.Embedding(max_seq_len, emb)

        blocks = []
        for _ in range(depth):
            f = WithLayerNorm(emb, LSHSelfAttention(emb, heads, bucket_size, n_hashes))
            g = Chunk(ff_chunks, WithLayerNorm(emb, FeedForward(emb)), dim=-2)
            blocks.append(ReversibleBlock(f, g, dim=-1))

        self.layers = ReversibleSequence(nn.ModuleList(blocks))
        self.to_logits = nn.Linear(emb, num_tokens)

    def forward(self, x):
        x = self.token_emb(x) + self.pos_emb(torch.arange(x.shape[1], device=x.device))
        # Duplicate the embedding along the feature dimension so the reversible
        # blocks can split it into the (x1, x2) halves, then sum the halves back.
        x = torch.cat([x, x], dim=-1)
        x = self.layers(x)
        x = torch.stack(x.chunk(2, dim=-1)).sum(dim=0)
        return self.to_logits(x)

num_tokens = 10000

r = Reformer(
    emb=512,
    depth=12,
    max_seq_len=1024,
    num_tokens=num_tokens,
    heads=8,
    bucket_size=64,
    n_hashes=8,
    ff_chunks=200
)

x = torch.randint(0, num_tokens, (1, 1024)).long()
y = r(x)
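
# Illustrative follow-up (an assumed training step, not from the original code):
# the logits have shape (batch, seqlen, num_tokens), so a standard language
# modelling loss can be applied directly and backpropagated through the
# reversible layers.
loss = F.cross_entropy(y.view(-1, num_tokens), x.view(-1))
loss.backward()
print('loss:', loss.item(), 'logits shape:', tuple(y.shape))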