Skip to content

Instantly share code, notes, and snippets.

@tokestermw
Forked from lucidrains/pytorch_reformer.py
Created January 22, 2020 07:29
Show Gist options
  • Save tokestermw/4187abe92db89b15f9a336a0d155334a to your computer and use it in GitHub Desktop.
Save tokestermw/4187abe92db89b15f9a336a0d155334a to your computer and use it in GitHub Desktop.

Revisions

  1. @lucidrains lucidrains revised this gist Jan 8, 2020. 1 changed file with 24 additions and 12 deletions.
    36 changes: 24 additions & 12 deletions pytorch_reformer.py
    Original file line number Diff line number Diff line change
    @@ -2,6 +2,23 @@
    import torch.nn as nn
    import torch.nn.functional as F

    # helpers

    def make_unit_length(x, epsilon=1e-6):
    norm = x.norm(p=2, dim=-1, keepdim=True)
    return x.div(norm + epsilon)

    def sort_key_val(t1, t2, dim=-1):
    values, indices = t1.sort(dim=dim)
    t2 = t2.expand_as(t1)
    return values, t2.gather(dim, indices)

    def batched_index_select(values, indices):
    b = values.shape[0]
    return values[torch.arange(0, b), indices.transpose(0, 1)].transpose(0, 1)

    # reversible net helper classes

    class ReversibleBlock(nn.Module):
    def __init__(self, f_block, g_block, dim = 1):
    super().__init__()
    @@ -84,18 +101,7 @@ def forward(self, x):
    x = _ReversibleModuleFunction.apply(x, self.reversible_blocks)
    return x

    def make_unit_length(x, epsilon=1e-6):
    norm = x.norm(p=2, dim=-1, keepdim=True)
    return x.div(norm + epsilon)

    def sort_key_val(t1, t2, dim=-1):
    values, indices = t1.sort(dim=dim)
    t2 = t2.expand_as(t1)
    return values, t2.gather(dim, indices)

    def batched_index_select(values, indices):
    b = values.shape[0]
    return values[torch.arange(0, b), indices.transpose(0, 1)].transpose(0, 1)
    # lsh attention

    class LSHAttention(nn.Module):
    def __init__( self,
    @@ -328,6 +334,8 @@ def split_heads(v):

    return self.unify_heads(out)

    # feedforward with chunking

    class FeedForward(nn.Module):
    def __init__(self, emb, mult = 4):
    super().__init__()
    @@ -362,6 +370,8 @@ def forward(self, x):
    chunks = x.chunk(self.chunks, dim = self.dim)
    return torch.cat([self.fn(c) for c in chunks], dim = self.dim)

    # reformer auto-regressive lm

    class Reformer(nn.Module):
    def __init__(self, emb, depth, max_seq_len, num_tokens = 10000, heads = 8, bucket_size = 64, n_hashes = 8, ff_chunks = 100):
    super().__init__()
    @@ -385,6 +395,8 @@ def forward(self, x):
    x = self.layers(x)
    x = torch.stack(x.chunk(2, dim=-1)).sum(dim=0)
    return self.to_logits(x)

    # testing

    num_tokens = 10000

  2. @lucidrains lucidrains created this gist Jan 8, 2020.
    403 changes: 403 additions & 0 deletions pytorch_reformer.py
    Original file line number Diff line number Diff line change
    @@ -0,0 +1,403 @@
    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    class ReversibleBlock(nn.Module):
    def __init__(self, f_block, g_block, dim = 1):
    super().__init__()
    self.dim = dim
    self.f_block = f_block
    self.g_block = g_block

    def forward(self, x):
    x1, x2 = torch.chunk(x, 2, dim=self.dim)
    y1, y2 = None, None
    with torch.no_grad():
    y1 = x1 + self.f_block(x2)
    y2 = x2 + self.g_block(y1)

    return torch.cat([y1, y2], dim=self.dim)

    def backward_pass(self, y, dy):
    y1, y2 = torch.chunk(y, 2, dim=self.dim)
    del y
    dy1, dy2 = torch.chunk(dy, 2, dim=self.dim)
    del dy

    y1.requires_grad = True
    y2.requires_grad = True

    with torch.enable_grad():
    gy1 = self.g_block(y1)
    gy1.backward(dy2)

    with torch.no_grad():
    x2 = y2 - gy1
    del y2, gy1

    dx1 = dy1 + y1.grad
    del dy1
    y1.grad = None

    with torch.enable_grad():
    x2.requires_grad = True
    fx2 = self.f_block(x2)
    fx2.backward(dx1)

    with torch.no_grad():
    x1 = y1 - fx2
    del y1, fx2

    dx2 = dy2 + x2.grad
    del dy2
    x2.grad = None

    x = torch.cat([x1, x2.detach()], dim=self.dim)
    dx = torch.cat([dx1, dx2], dim=self.dim)

    return x, dx

    class _ReversibleModuleFunction(torch.autograd.function.Function):
    @staticmethod
    def forward(ctx, x, reversible_blocks):
    for block in reversible_blocks:
    x = block(x)
    ctx.y = x.detach()
    ctx.reversible_blocks = reversible_blocks
    return x

    @staticmethod
    def backward(ctx, dy):
    y = ctx.y
    del ctx.y
    for i in range(len(ctx.reversible_blocks) - 1, -1, -1):
    y, dy = ctx.reversible_blocks[i].backward_pass(y, dy)
    del ctx.reversible_blocks
    return dy, None

    class ReversibleSequence(nn.Module):
    def __init__(self, reversible_blocks):
    super().__init__()
    self.reversible_blocks = reversible_blocks

    def forward(self, x):
    x = _ReversibleModuleFunction.apply(x, self.reversible_blocks)
    return x

    def make_unit_length(x, epsilon=1e-6):
    norm = x.norm(p=2, dim=-1, keepdim=True)
    return x.div(norm + epsilon)

    def sort_key_val(t1, t2, dim=-1):
    values, indices = t1.sort(dim=dim)
    t2 = t2.expand_as(t1)
    return values, t2.gather(dim, indices)

    def batched_index_select(values, indices):
    b = values.shape[0]
    return values[torch.arange(0, b), indices.transpose(0, 1)].transpose(0, 1)

    class LSHAttention(nn.Module):
    def __init__( self,
    dropout = 0.,
    bucket_size = 64,
    n_hashes = 8,
    allow_duplicate_attention = False,
    attend_across_buckets = False,
    rehash_each_round = True,
    drop_for_hash_rate = 0.0):
    super().__init__()

    if dropout >= 1.0:
    raise ValueError('Dropout rates must be lower than 1.')

    self.dropout = nn.Dropout(dropout)
    self.dropout_for_hash = nn.Dropout(drop_for_hash_rate)

    assert rehash_each_round or allow_duplicate_attention, (
    'The setting {allow_duplicate_attention=False, rehash_each_round=False}'
    ' is not implemented.')

    self.n_hashes = n_hashes
    self.bucket_size = bucket_size

    self._allow_duplicate_attention = allow_duplicate_attention
    self._attend_across_buckets = attend_across_buckets
    self._rehash_each_round = rehash_each_round

    def _sample_rotation(self, shape, vecs):
    device = vecs.device
    return torch.randn(shape, device=device)

    def hash_vectors(self, n_buckets, vecs):
    batch_size = vecs.shape[0]
    device = vecs.device

    # See https://arxiv.org/pdf/1509.02897.pdf
    # We sample a different random rotation for each round of hashing to
    # decrease the probability of hash misses.
    assert n_buckets % 2 == 0

    rot_size = n_buckets

    rotations_shape = (
    vecs.shape[-1],
    self.n_hashes if self._rehash_each_round else 1,
    rot_size // 2)

    random_rotations = self._sample_rotation(rotations_shape, vecs)

    dropped_vecs = self.dropout_for_hash(vecs)
    rotated_vecs = torch.einsum('btf,fhi->bhti', dropped_vecs, random_rotations)

    if self._rehash_each_round:
    rotated_vecs = torch.cat([rotated_vecs, -rotated_vecs], dim=-1)
    buckets = torch.argmax(rotated_vecs, axis=-1)
    # buckets is now (self.n_hashes, seqlen). Next we add offsets so that
    # bucket numbers from different hashing rounds don't overlap.
    offsets = torch.arange(self.n_hashes, device=device)
    offsets = torch.reshape(offsets * n_buckets, (1, -1, 1))
    buckets = torch.reshape(buckets + offsets, (batch_size, -1,))
    else:
    assert not self._factorize_hash
    rotated_vecs = torch.cat([rotated_vecs, -rotated_vecs], dim=-1)
    # In this configuration, we map each item to the top self.n_hashes buckets
    rotated_vecs = torch.squeeze(rotated_vecs, 0)
    bucket_range = torch.arange(0, rotated_vecs.shape[-1], device=device)
    bucket_range = torch.reshape(bucket_range, (1, -1))
    bucket_range = bucket_range.expand_as(rotated_vecs.shape)

    _, buckets = sort_key_val(rotated_vecs, bucket_range, dim=-1)
    buckets = buckets[:, -self.n_hashes:]

    h, *_ = buckets.shape
    buckets = torch.reshape(buckets.permute((*_, h)), (-1,))

    return buckets

    def forward(self, qk, v):
    batch_size, seqlen, _ = qk.shape
    device = qk.device

    n_buckets = seqlen // self.bucket_size
    n_bins = n_buckets

    buckets = self.hash_vectors(n_buckets, qk)
    # We use the same vector as both a query and a key.
    assert int(buckets.shape[1]) == self.n_hashes * seqlen

    ticker = torch.arange(0, self.n_hashes * seqlen, device=device).unsqueeze(0)
    buckets_and_t = seqlen * buckets + (ticker % seqlen)
    buckets_and_t = buckets_and_t.detach()

    # Hash-based sort ("s" at the start of variable names means "sorted")
    sbuckets_and_t, sticker = sort_key_val(buckets_and_t, ticker, dim=-1)
    _, undo_sort = sort_key_val(sticker, ticker, dim=-1)

    sbuckets_and_t = sbuckets_and_t.detach()
    sticker = sticker.detach()
    undo_sort = undo_sort.detach()

    st = (sticker % seqlen)
    sqk = batched_index_select(qk, st)
    sv = batched_index_select(v, st)

    # Split off a "bin" axis so that attention only occurs within chunks.
    bq_t = bkv_t = torch.reshape(st, (batch_size, self.n_hashes * n_bins, -1))
    bqk = torch.reshape(sqk, (batch_size, self.n_hashes * n_bins, -1, sqk.shape[-1]))
    bv = torch.reshape(sv, (batch_size, self.n_hashes * n_bins, -1, sv.shape[-1]))
    bq_buckets = bkv_buckets = torch.reshape(sbuckets_and_t // seqlen, (batch_size, self.n_hashes * n_bins, -1))

    # Hashing operates on unit-length vectors. Unnormalized query vectors are
    # fine because they effectively provide a learnable temperature for the
    # attention softmax, but normalizing keys is needed so that similarity for
    # the purposes of attention correctly corresponds to hash locality.
    bq = bqk
    bk = make_unit_length(bqk)

    # Allow each chunk to attend within itself, and also one chunk back. Chunk
    # boundaries might occur in the middle of a sequence of items from the
    # same bucket, so this increases the chances of attending to relevant items.
    def look_one_back(x):
    x_extra = torch.cat([x[:, -1:, ...], x[:, :-1, ...]], dim=1)
    return torch.cat([x, x_extra], dim=2)

    bk = look_one_back(bk)
    bv = look_one_back(bv)
    bkv_t = look_one_back(bkv_t)
    bkv_buckets = look_one_back(bkv_buckets)

    # Dot-product attention.
    dots = torch.einsum('bhie,bhje->bhij', bq, bk) / (bq.shape[-1] ** -0.5)

    # Causal masking
    mask = bq_t[:, :, :, None] < bkv_t[:, :, None, :]
    dots = dots - 1e9 * mask

    # Mask out attention to self except when no other targets are available.
    self_mask = bq_t[:, :, :, None] == bkv_t[:, :, None, :]
    dots = dots - 1e5 * self_mask

    # Mask out attention to other hash buckets.
    if not self._attend_across_buckets:
    bucket_mask = bq_buckets[:, :, :, None] != bkv_buckets[:, :, None, :]
    dots = dots - 1e7 * bucket_mask

    # Don't double-count query-key pairs across multiple rounds of hashing.
    # There are two possible strategies here. (1) The default is to count how
    # many times a query-key pair is repeated, and to lower its log-prob
    # correspondingly at each repetition. (2) When hard_k is set, the code
    # instead masks all but the first occurence of each query-key pair.
    if not self._allow_duplicate_attention:
    locs1 = undo_sort // bq_t.shape[-1]
    locs2 = (locs1 + 1) % (self.n_hashes * n_bins)
    if not self._attend_across_buckets:
    locs1 = buckets * (self.n_hashes * n_bins) + locs1
    locs2 = buckets * (self.n_hashes * n_bins) + locs2
    locs = torch.cat([
    torch.reshape(locs1, (batch_size, self.n_hashes, seqlen)),
    torch.reshape(locs2, (batch_size, self.n_hashes, seqlen)),
    ], 1).permute((0, 2, 1))

    slocs = batched_index_select(locs, st)
    b_locs = torch.reshape(slocs, (batch_size, self.n_hashes * n_bins, -1, 2 * self.n_hashes))

    b_locs1 = b_locs[:, :, :, None, :self.n_hashes]

    bq_locs = b_locs1.expand(b_locs.shape[:3] + (2, self.n_hashes))
    bq_locs = torch.reshape(bq_locs, b_locs.shape)
    bkv_locs = look_one_back(b_locs)

    dup_counts = (bq_locs[:, :, :, None, :] == bkv_locs[:, :, None, :, :]).float().sum(dim=-1)
    dup_counts = dup_counts.detach()
    assert dup_counts.shape == dots.shape
    dots = dots - torch.log(dup_counts + 1e-9)

    # Softmax.
    dots_logsumexp = torch.logsumexp(dots, dim=-1, keepdim=True)
    dots = torch.exp(dots - dots_logsumexp)
    dots = self.dropout(dots)

    bo = torch.einsum('buij,buje->buie', dots, bv)
    so = torch.reshape(bo, (batch_size, -1, bo.shape[-1]))
    slogits = torch.reshape(dots_logsumexp, (batch_size, -1,))

    o = batched_index_select(so, undo_sort)
    _, logits = sort_key_val(sticker, slogits, dim=-1)

    if self.n_hashes == 1:
    out = o
    else:
    o = torch.reshape(o, (batch_size, self.n_hashes, seqlen, o.shape[-1]))
    logits = torch.reshape(logits, (batch_size, self.n_hashes, seqlen, 1))
    probs = torch.exp(logits - torch.logsumexp(logits, dim=1, keepdims=True))
    out = torch.sum(o * probs, dim=1)

    assert out.shape == v.shape
    return out

    class LSHSelfAttention(nn.Module):
    def __init__(self, emb, heads = 8, bucket_size = 64, n_hashes = 8, **kwargs):
    super().__init__()
    self.heads = heads

    self.toqk = nn.Linear(emb, emb * heads)
    self.tov = nn.Linear(emb, emb * heads)
    self.unify_heads = nn.Linear(emb * heads, emb)

    self.bucket_size = bucket_size
    self.lsh_attn = LSHAttention(bucket_size=bucket_size, **kwargs)

    def forward(self, x):
    b, t, e, h = *x.shape, self.heads
    assert t % self.bucket_size == 0, f'Sequence length needs to be divisible by target bucket size - {self.bucket_size}'

    qk = self.toqk(x)
    v = self.tov(x)

    def merge_heads(v):
    return v.view(b, t, h, e).transpose(1, 2).reshape(b * h, t, e)

    def split_heads(v):
    return v.view(b, h, t, e).transpose(1, 2).contiguous()

    qk = merge_heads(qk)
    v = merge_heads(v)
    attn_out = self.lsh_attn(qk, v)
    out = split_heads(attn_out).view(b, t, h * e)

    return self.unify_heads(out)

    class FeedForward(nn.Module):
    def __init__(self, emb, mult = 4):
    super().__init__()
    self.emb = emb
    self.proj_in = nn.Linear(emb, emb * mult)
    self.proj_out = nn.Linear(emb * mult, emb)

    def forward(self, x):
    x = self.proj_in(x)
    x = F.gelu(x)
    x = self.proj_out(x)
    return x

    class WithLayerNorm(nn.Module):
    def __init__(self, emb, fn):
    super().__init__()
    self.emb = emb
    self.norm = nn.LayerNorm(emb)
    self.fn = fn
    def forward(self, x):
    x = self.norm(x)
    return self.fn(x)

    class Chunk(nn.Module):
    def __init__(self, chunks, fn, dim = -1):
    super().__init__()
    self.dim = dim
    self.chunks = chunks
    self.fn = fn

    def forward(self, x):
    chunks = x.chunk(self.chunks, dim = self.dim)
    return torch.cat([self.fn(c) for c in chunks], dim = self.dim)

    class Reformer(nn.Module):
    def __init__(self, emb, depth, max_seq_len, num_tokens = 10000, heads = 8, bucket_size = 64, n_hashes = 8, ff_chunks = 100):
    super().__init__()
    self.emb = emb
    self.depth = depth
    self.token_emb = nn.Embedding(num_tokens, emb)
    self.pos_emb = nn.Embedding(max_seq_len, emb)

    blocks = []
    for _ in range(depth):
    f = WithLayerNorm(emb, LSHSelfAttention(emb, heads, bucket_size, n_hashes))
    g = Chunk(ff_chunks, WithLayerNorm(emb, FeedForward(emb)), dim = -2)
    blocks.append(ReversibleBlock(f, g, dim=-1))

    self.layers = ReversibleSequence(nn.ModuleList(blocks))
    self.to_logits = nn.Linear(emb, num_tokens)

    def forward(self, x):
    x = self.token_emb(x) + self.pos_emb(torch.arange(0, x.shape[1]))
    x = torch.cat([x, x], dim = -1)
    x = self.layers(x)
    x = torch.stack(x.chunk(2, dim=-1)).sum(dim=0)
    return self.to_logits(x)

    num_tokens = 10000

    r = Reformer(
    emb = 512,
    depth = 12,
    max_seq_len = 1024,
    num_tokens= num_tokens,
    heads = 8,
    bucket_size = 64,
    n_hashes = 8,
    ff_chunks = 200
    )

    x = torch.randint(0, num_tokens, (1, 1024)).long()
    y = r(x)