@proger
Last active October 8, 2024 04:13
Revisions

  1. proger revised this gist Jul 16, 2024. 1 changed file with 1 addition and 1 deletion.
    2 changes: 1 addition & 1 deletion deltanet.py
    @@ -452,7 +452,7 @@ def backward(ctx, d_y):


    def test_delta():
    NH, T, D = 2, 16, 3
    NH, T, D = 1, 64, 16
    q1, k1, v1, beta1 = make_example(NH, T, D)

    y0 = forward_loop(q1, k1, v1, beta1)
  2. proger revised this gist Jul 16, 2024. 1 changed file with 150 additions and 121 deletions.
    271 changes: 150 additions & 121 deletions deltanet.py
    @@ -1,13 +1,8 @@
    """
    DeltaNet implementation reference for Accelerated Scan. DeltaNet performs efficient management of a large fixed-size memory.
    For a simple single chunk version see `forward_simple`.
    It computes decayed values by a little bit of recurrence and then applies linear attention (`decay_values`).
    `forward_chunkwise` is inspired by Yang 2024. It applies single chunk version pointwise and
    then performs chunk-level stitching.
    forward_loop is the reference implementation of the original recurrence.
    `forward` is inspired by Yang 2024. It applies single chunk version pointwise and then performs chunk-level stitching.
    `forward_loop` is the reference implementation of the original recurrence.
    References:
    [1] The WY Representation for Products of Householder Matrices (Bischof and Van Loan 1985)
    @@ -63,78 +58,82 @@ def fmt(r,c,tag):
    ]))


    def decay_values(q, k, v, beta):
    "decay values applying deltanet forgetting rules"
    def decay_values(q, k, v, beta, chunk_size=2):
    NH, T, D = shape(q, k, v, beta)
    C = T // chunk_size
    q_, k_, v_, beta_ = (
    q.view(NH*C, chunk_size, D), k.view(NH*C, chunk_size, D),
    v.view(NH*C, chunk_size, D), beta.view(NH*C, chunk_size)
    )

    beta_ = beta.unsqueeze(-1)
    w = beta_ * k.clone()
    u = beta_ * v.clone()
    K = einsum('nsd,ntd->nst', k, k) # (T,T) matrix
    # evaluate all chunks in parallel
    beta__ = beta_.unsqueeze(-1)
    w = beta__ * k_.clone()
    u = beta__ * v_.clone()
    K = einsum('nsd,ntd->nst', k_, k_) # (chunk_size,chunk_size) matrix

    for t in range(1,T):
    w[:, t] -= beta_[:, t] * einsum('nt,ntd->nd', K[:, :t, t], w[:, :t].clone())
    u[:, t] -= beta_[:, t] * einsum('nt,ntd->nd', K[:, :t, t], u[:, :t].clone())
    for t in range(1,chunk_size):
    w[:, t] -= beta__[:, t] * einsum('nt,ntd->nd', K[:, :t, t], w[:, :t].clone())
    u[:, t] -= beta__[:, t] * einsum('nt,ntd->nd', K[:, :t, t], u[:, :t].clone())

    # attend to decayed values
    qk = einsum("nsk,ntk->nst", q, k)
    qk = einsum("nsk,ntk->nst", q_, k_)
    qk.tril_()
    y = einsum("nst,ntj->nsj", qk, u)

    return w, u, y


    def forward_chunkwise(q, k, v, beta, chunk_size=2):
    def forward(q, k, v, beta, chunk_size=2):
    "decay values applying deltanet forgetting rules, then stitch chunks"
    NH, T, D = shape(q, k, v, beta)
    C = T // chunk_size
    q_, k_, v_, beta_ = (
    q.view(NH*C, chunk_size, D), k.view(NH*C, chunk_size, D),
    v.view(NH*C, chunk_size, D), beta.view(NH*C, chunk_size)
    )

    # evaluate all chunks in parallel
    w, u, y = decay_values(q_, k_, v_, beta_)
    w, u, y = decay_values(q, k, v, beta, chunk_size=chunk_size)

    # stitch chunks sequentially
    q_ = q.view(NH, C, chunk_size, D)
    k_ = k.view(NH, C, chunk_size, D)
    u = u.view(NH, C, chunk_size, D)
    w = w.view(NH, C, chunk_size, D)
    y = y.view(NH, C, chunk_size, D)

    # materialize the state for the leading chunk
    state = einsum('ntv,ntk->nvk', u[:, 0], k_[:, 0])

    print(y, 'y')
    kc = k_[:, 0]
    uc = u[:, 0]

    state = u.new_zeros(NH, D, D)

    y_deltas = [u.new_zeros(NH, chunk_size, D)]
    for c in range(1, C):
    if c == 1:
    # early on this is cheaper than a state query
    kw = einsum('nsk,ntk->nst', k_[:, c-1], w[:, c]) # TDT
    u_old = einsum('nsv,nst->ntv', u[:, c-1], kw) # TDT
    state = state + einsum('ntv,ntk->nvk', uc, kc)

    kq = einsum('nsk,ntk->nst', k_[:, c-1], q_[:, c]) # TDT
    y_cur = einsum('nsv,nst->ntv', u[:, c-1], kq) # TTD
    else:
    u_old = einsum('nvk,ntk->ntv', state, w[:, c]) # DDT
    y_cur = einsum('nvk,nsk->nsv', state, q_[:, c]) # DDT
    wc = w[:, c] # load w
    uc = einsum('ntk,nvk->ntv', wc, state) # DDT

    qc = q_[:, c] # load q
    kc = k_[:, c] # load k

    # attend to old values
    qk = einsum("nsi,nti->nst", q_[:, c], k_[: ,c]) # TDT
    qk = einsum("nsi,nti->nst", qc, kc) # TDT
    qk = qk.tril()

    y_prev = einsum("nst,ntv->nsv", qk, u_old) # TTD
    yc = y[:, c].clone() # load y

    y_deltas.append(y_cur - y_prev)
    y_prev = einsum("nst,ntv->nsv", qk, uc) # TTD
    yc = yc - y_prev

    state_delta = einsum('ntv,ntk->nvk', u[:, c] - u_old, k_[:, c])
    state = state + state_delta
    y_cur = einsum('nsk,nvk->nsv', qc, state) # DDT
    yc = yc + y_cur

    y_delta = torch.stack(y_deltas, dim=1)
    y[:, c] = yc # store

    u1 = u[:, c] # load u
    uc = u1 - uc

    w = w.view(NH, T, D)
    u = u.view(NH, T, D)
    y = y.view(NH, T, D) + y_delta.view(NH, T, D)
    y = y.view(NH, T, D)

    return w, u, y

    @@ -186,8 +185,15 @@ def make_example(NH, T, D, device='cpu', dtype=torch.float32):


    @no_grad()
    def decay_values_backward(d_out_w, d_out_u, d_out_y, q, k, v, beta):
    NH, T, D = shape(q, k, v, beta)
    def backward(d_out_w_long, d_out_u_long, d_out_y_long, q_long, k_long, v_long, beta_long, chunk_size=2):
    NH, T, D = shape(q_long, k_long, v_long, beta_long)

    C = T // chunk_size
    q, k, v, beta, d_out_y = (
    q_long.view(NH*C, chunk_size, D), k_long.view(NH*C, chunk_size, D),
    v_long.view(NH*C, chunk_size, D), beta_long.view(NH*C, chunk_size),
    d_out_y_long.view(NH*C, chunk_size, D)
    )

    #
    # allocations
    @@ -198,21 +204,21 @@ def decay_values_backward(d_out_w, d_out_u, d_out_y, q, k, v, beta):
    k = k.clone() # load k
    v = v.clone() # load v
    beta = beta.clone() # load beta
    d_out_w = d_out_w.clone() # ntk
    d_out_y = d_out_y.clone() # ntv
    #d_out_w = d_out_w.clone() # ntk # placeholders
    #d_out_y = d_out_y.clone() # ntv # placeholders

    w = k.new_zeros(NH, T, D) # ntk
    u = v.new_zeros(NH, T, D) # ntw
    w = k.new_zeros(NH*C, chunk_size, D) # ntk
    u = v.new_zeros(NH*C, chunk_size, D) # ntw
    w_bases = w.clone() # ntk
    u_bases = u.clone() # ntw

    bk = einsum('nt,ntk->ntk', beta, k)

    bKl = k.new_zeros(NH, T, T)
    tt = k.new_zeros(NH, T, T)
    bKl = k.new_zeros(NH*C, chunk_size, chunk_size)
    tt = k.new_zeros(NH*C, chunk_size, chunk_size)

    d_k = k.new_zeros(NH, T, D) # nsk
    tk = k.new_zeros(NH, T, D) # ntk
    d_k = k.new_zeros(NH*C, chunk_size, D) # nsk
    tk = k.new_zeros(NH*C, chunk_size, D) # ntk

    #
    # forward
    @@ -225,7 +231,7 @@ def decay_values_backward(d_out_w, d_out_u, d_out_y, q, k, v, beta):
    u_bases = v
    v = einsum('nt,ntw->ntw', beta, v)

    for t in range(T):
    for t in range(chunk_size):
    tk = einsum('nts,nsk->ntk', bKl, w) # matmul for the sake of one row
    w[:, t] = bk[:, t, :] - tk[:, t, :]
    tk = einsum('nts,nsw->ntw', bKl, u) # matmul for the sake of one row
    @@ -234,6 +240,17 @@ def decay_values_backward(d_out_w, d_out_u, d_out_y, q, k, v, beta):
    w.clone() # store w
    u.clone() # store u

    #
    # stitch_backward
    #
    w_long = w.view(NH, T, D)
    u_long = u.view(NH, T, D)
    d_q_1_long, d_k_1_long, d_out_w_long, d_out_u_long = stitch_backward(d_out_y_long, q_long, k_long, w_long, u_long, C, chunk_size)

    d_out_w, d_out_u = (
    d_out_w_long.view(NH*C, chunk_size, D), d_out_u_long.view(NH*C, chunk_size, D)
    )

    w_bases = einsum('nts,nsk->ntk', tt, w)
    w_bases = k - w_bases
    v = einsum('nts,nsw->ntw', tt, u)
    @@ -264,7 +281,7 @@ def decay_values_backward(d_out_w, d_out_u, d_out_y, q, k, v, beta):

    d_k.zero_()

    for t in range(T-1,-1,-1):
    for t in range(chunk_size-1,-1,-1):
    # d_k
    tt = einsum('njw,ntw->njt', w, d_out_w) # matmul for the sake of one column t
    tt[:, t:, :] = 0
    @@ -311,132 +328,143 @@ def decay_values_backward(d_out_w, d_out_u, d_out_y, q, k, v, beta):
    beta += einsum('ntv->nt', u_bases)
    d_beta = beta.clone() # store

    return d_q, d_k, d_v, d_beta
    d_q_long = d_q.view(NH, T, D) + d_q_1_long
    d_k_long = d_k.view(NH, T, D) + d_k_1_long
    d_v_long = d_v.view(NH, T, D)
    d_beta_long = d_beta.view(NH, T)

    return d_q_long, d_k_long, d_v_long, d_beta_long


    def stitch_backward(d_y_delta, q, k, w, u, C, chunk_size):
    NH, T, D = shape(q, k, None, None)

    # outputs
    d_q_ = q.new_zeros(NH, C, chunk_size, D)
    d_k_ = k.new_zeros(NH, C, chunk_size, D)
    d_w = w.new_zeros(NH, C, chunk_size, D)
    d_u = u.new_zeros(NH, C, chunk_size, D)

    # chunked inputs
    d_y_delta = d_y_delta.view(NH, C, chunk_size, D)
    q_ = q.view(NH, C, chunk_size, D)
    k_ = k.view(NH, C, chunk_size, D)
    u = u.view(NH, C, chunk_size, D).clone()
    w = w.view(NH, C, chunk_size, D)

    d_q_ = q.new_zeros(NH, C, chunk_size, D)
    d_k_ = k.new_zeros(NH, C, chunk_size, D)
    d_w = w.new_zeros(NH, C, chunk_size, D)
    d_u = u.new_zeros(NH, C, chunk_size, D)
    # shared memory copy
    u = u.view(NH, C, chunk_size, D).clone()

    state = w.new_zeros(NH, D, D)
    d_state = w.new_zeros(NH, D, D) # NHVK
    state_delta = w.new_zeros(NH, D, D) # NHVK # can this be float32?

    qk = k.new_zeros(NH, chunk_size, C)
    tk = k.new_zeros(NH, chunk_size, D)

    # materialize the state for the leading chunk
    state = einsum('ntv,ntk->nvk', u[:, 0], k_[:, 0])

    # stitch forward
    for c in range(1, C):
    state_decay = einsum('nvk,ntk->ntv', state, w[:, c])
    u[:, c] = u[:, c] - state_decay
    tk = einsum('nvk,ntk->ntv', state, w[:, c])
    u[:, c] = u[:, c] - tk
    state_delta = einsum('ntv,ntk->nvk', u[:, c], k_[:, c])
    if c < C-1:
    state = state + state_delta
    state = state + state_delta # walk the state forwards

    # from now on, u's are decayed

    # stitch backward
    for c in range(C-1, 0, -1):
    d_u[:, c] = einsum('nvk,ntk->ntv', d_state, k_[:, c]) # state_add

    if c < C-1:
    state_delta = einsum('ntv,ntk->nvk', u[:, c], k_[:, c])
    state = state - state_delta # uncompute the state
    state_decay = einsum('nvk,ntk->ntv', state, w[:, c])
    state = state - state_delta # uncompute the state backwards
    tk = einsum('nvk,ntk->ntv', state, w[:, c]) # state_decay

    d_q1 = einsum('nsv,nvk->nsk', d_y_delta[:, c], state) # prev_output
    d_y_delta_c = d_y_delta[:, c]
    d_y_delta_c = -d_y_delta_c # neg

    # causal_attend_backward for delta
    d_out_att = -d_y_delta[:, c]
    d_out_state_decays = einsum('nsv,ntv->nst', d_out_att, state_decay)
    d_out_state_decays.tril_()
    d_q2 = einsum('nst,ntk->nsk', d_out_state_decays, k_[:, c])
    d_k1 = einsum('nst,nsk->ntk', d_out_state_decays, q_[:, c])
    d_q_[:, c] = d_q1 + d_q2
    # d_q, d_k
    qk = einsum('nsv,ntv->nst', d_y_delta_c, tk)
    qk.tril_()

    d_k2 = einsum('nvk,ntv->ntk', d_state, u[:, c]) # state_add
    d_k_[:, c] = d_k1 + d_k2
    # d_q
    tk = einsum('nst,ntk->nsk', qk, k_[:, c]) # causal_attend_backward for delta
    tk.sub_(einsum('nsv,nvk->nsk', d_y_delta_c, state)) # prev_output
    d_q_[:, c] = tk

    d_state1 = einsum('nsv,nsk->nvk', d_y_delta[:, c], q_[:, c]) # prev_output
    # d_k
    tk = einsum('nst,nsk->ntk', qk, q_[:, c])
    if c < C-1:
    tk.add_(einsum('nvk,ntv->ntk', d_state, u[:, c])) # state_add
    else:
    # d_state is zero
    pass
    d_k_[:, c] = tk

    # d_u
    if c < C-1:
    d_u[:, c] = einsum('nvk,ntk->ntv', d_state, k_[:, c]) # state_add
    else:
    # d_state is zero
    pass

    # d_state_decays
    qk = einsum('nsk,ntk->nst', q_[:, c], k_[:, c])
    qk.tril_()
    d_state_decays1 = einsum('nsv,nst->ntv', d_out_att, qk)
    d_state_decays = einsum('nsv,nst->ntv', d_y_delta_c, qk)
    if c < C-1:
    d_state_decays.sub_(einsum('nvk,ntk->ntv', d_state, k_[:, c])) # state_add

    d_state_decays2 = einsum('nvk,ntk->ntv', -d_state, k_[:, c]) # state_add
    d_state_decays = d_state_decays1 + d_state_decays2
    d_w[:, c] = einsum('ntv,nvk->ntk', d_state_decays, state) # state_decays
    # d_w
    tk = einsum('ntv,nvk->ntk', d_state_decays, state)
    d_w[:, c] = tk # state_decays

    d_state2 = einsum('ntv,ntk->nvk', d_state_decays, w[:, c]) # state_decays
    d_state = d_state + d_state1 + d_state2
    # backpropagate through time
    d_state.sub_(einsum('nsv,nsk->nvk', d_y_delta_c, q_[:, c])) # prev_output
    d_state.add_(einsum('ntv,ntk->nvk', d_state_decays, w[:, c])) # state_decays

    d_u[:, 0] = einsum('nvk,ntk->ntv', d_state, k_[:, 0]) # state_add
    d_k_[:, 0] = einsum('nvk,ntv->ntk', d_state, u[:, 0]) # state_add
    tk = einsum('nvk,ntk->ntv', d_state, k_[:, 0])
    d_u[:, 0] = tk # state_add
    tk = einsum('nvk,ntv->ntk', d_state, u[:, 0])
    d_k_[:, 0] = tk # state_add

    return d_q_.view(NH, T, D), d_k_.view(NH, T, D), d_w.view(NH, T, D), d_u.view(NH, T, D)


    class DeltaChunkwise(torch.autograd.Function):
    class Delta(torch.autograd.Function):
    @staticmethod
    def forward(ctx, q, k, v, beta, chunk_size):
    w, u, y = forward_chunkwise(q, k, v, beta, chunk_size)
    ctx.save_for_backward(q, k, v, beta, w, u)
    w, u, y = forward(q, k, v, beta, chunk_size)
    ctx.save_for_backward(q, k, v, beta)
    ctx.chunk_size = chunk_size
    return y

    @staticmethod
    def backward(ctx, d_y):
    q, k, v, beta, w, u = ctx.saved_tensors
    q, k, v, beta = ctx.saved_tensors
    NH, T, D = shape(q, k, v, beta)
    chunk_size = ctx.chunk_size
    C = T // chunk_size

    q_, k_, v_, beta_ = (
    q.view(NH*C, chunk_size, D), k.view(NH*C, chunk_size, D),
    v.view(NH*C, chunk_size, D), beta.view(NH*C, chunk_size)
    )

    d_q_1, d_k_1, d_w, d_u = stitch_backward(d_y, q, k, w, u, C=C, chunk_size=chunk_size)

    d_w = d_w.view(NH*C, chunk_size, D)
    d_u = d_u.view(NH*C, chunk_size, D)
    d_y = d_y.view(NH*C, chunk_size, D)
    u = u.view(NH*C, chunk_size, D)

    d_q_2, d_k_2, d_v_, d_beta_ = decay_values_backward(d_w, d_u, d_y, q_, k_, v_, beta_)

    d_q_1 = d_q_1.view(NH, T, D)
    d_q_2 = d_q_2.view(NH, C, chunk_size, D)
    d_q_2 = d_q_2.reshape(NH, T, D)
    d_q = d_q_1 + d_q_2

    d_k_2 = d_k_2.reshape(NH, T, D)
    d_k = d_k_1 + d_k_2
    d_v = d_v_.reshape(NH, T, D)
    d_beta = d_beta_.view(NH, T)
    d_w = k.new_zeros(NH, T, D)
    d_u = v.new_zeros(NH, T, D)

    d_q, d_k, d_v, d_beta = backward(d_w, d_u, d_y, q, k, v, beta, chunk_size=ctx.chunk_size)
    return d_q, d_k, d_v, d_beta, None


    def test_delta_chunkwise_backward():
    def test_delta():
    NH, T, D = 2, 16, 3
    q1, k1, v1, beta1 = make_example(NH, T, D)

    y0 = forward_loop(q1, k1, v1, beta1)
    chunk_size = 8

    w1, u1, y1 = forward_chunkwise(q1, k1, v1, beta1, chunk_size=2)
    w1, u1, y1 = forward(q1, k1, v1, beta1, chunk_size=chunk_size)
    (y1 - torch.ones_like(y1).detach()).pow(2).mean().backward()

    assert allclose(y0, y1, atol=1e-5), 'y1 is wrong'

    q, k, v, beta = make_example(NH, T, D)
    y = DeltaChunkwise.apply(q, k, v, beta, 2)
    y = Delta.apply(q, k, v, beta, chunk_size)
    (y - torch.ones_like(y).detach()).pow(2).mean().backward()

    assert allclose(y1, y, atol=1e-5), 'y is wrong'
    @@ -452,4 +480,5 @@ def test_delta_chunkwise_backward():
    assert allclose(v1.grad, v.grad, atol=1e-5), 'v.grad is wrong'


    test_delta_chunkwise_backward()
    if __name__ == '__main__':
    test_delta()
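
The contract every one of these revisions preserves is the original delta rule recurrence that `forward_loop` implements: read the memory at k_t, write back a beta_t-weighted correction, then query with q_t. A minimal standalone sketch of that recurrence (the name `delta_rule_reference` is ours, not part of the gist):

    import torch
    from torch import einsum

    def delta_rule_reference(q, k, v, beta):
        "S_t = S_{t-1} + beta_t (v_t - S_{t-1} k_t) k_t.T; y_t = S_t q_t"
        NH, T, D = q.shape
        S = k.new_zeros(NH, D, D)  # (value, key) memory per head
        ys = []
        for t in range(T):
            k_t, v_t, b_t = k[:, t], v[:, t], beta[:, t].unsqueeze(-1)
            v_old = einsum('nvk,nk->nv', S, k_t)                    # recall the value stored at k_t
            S = S + einsum('nv,nk->nvk', b_t * (v_t - v_old), k_t)  # delta correction, rank-1 write
            ys.append(einsum('nvk,nk->nv', S, q[:, t]))             # query the updated state
        return torch.stack(ys, dim=1)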
  3. proger revised this gist Jul 15, 2024. 1 changed file with 26 additions and 24 deletions.
    50 changes: 26 additions & 24 deletions deltanet.py
    Original file line number Diff line number Diff line change
    @@ -336,46 +336,48 @@ def stitch_backward(d_y_delta, q, k, w, u, C, chunk_size):
    state_decay = einsum('nvk,ntk->ntv', state, w[:, c])
    u[:, c] = u[:, c] - state_decay
    state_delta = einsum('ntv,ntk->nvk', u[:, c], k_[:, c])
    state = state + state_delta
    if c < C-1:
    state = state + state_delta

    # from now on, u's are decayed

    # stitch backward
    for c in range(C-1, -1, -1):
    for c in range(C-1, 0, -1):
    d_u[:, c] = einsum('nvk,ntk->ntv', d_state, k_[:, c]) # state_add

    if c == 0:
    d_k_[:, c] = einsum('nvk,ntv->ntk', d_state, u[:, c]) # state_add
    else:
    if c < C-1:
    state_delta = einsum('ntv,ntk->nvk', u[:, c], k_[:, c])
    state = state - state_delta # uncompute the state
    state_decay = einsum('nvk,ntk->ntv', state, w[:, c])

    d_q1 = einsum('nsv,nvk->nsk', d_y_delta[:, c], state) # prev_output
    d_q1 = einsum('nsv,nvk->nsk', d_y_delta[:, c], state) # prev_output

    # causal_attend_backward for delta
    d_out_att = -d_y_delta[:, c]
    d_out_state_decays = einsum('nsv,ntv->nst', d_out_att, state_decay)
    d_out_state_decays.tril_()
    d_q2 = einsum('nst,ntk->nsk', d_out_state_decays, k_[:, c])
    d_k1 = einsum('nst,nsk->ntk', d_out_state_decays, q_[:, c])
    d_q_[:, c] = d_q1 + d_q2

    # causal_attend_backward for delta
    d_out_att = -d_y_delta[:, c]
    d_out_state_decays = einsum('nsv,ntv->nst', d_out_att, state_decay)
    d_out_state_decays.tril_()
    d_q2 = einsum('nst,ntk->nsk', d_out_state_decays, k_[:, c])
    d_k1 = einsum('nst,nsk->ntk', d_out_state_decays, q_[:, c])
    d_q_[:, c] = d_q1 + d_q2
    d_k2 = einsum('nvk,ntv->ntk', d_state, u[:, c]) # state_add
    d_k_[:, c] = d_k1 + d_k2

    d_k2 = einsum('nvk,ntv->ntk', d_state, u[:, c]) # state_add
    d_k_[:, c] = d_k1 + d_k2
    d_state1 = einsum('nsv,nsk->nvk', d_y_delta[:, c], q_[:, c]) # prev_output

    d_state1 = einsum('nsv,nsk->nvk', d_y_delta[:, c], q_[:, c]) # prev_output
    qk = einsum('nsk,ntk->nst', q_[:, c], k_[:, c])
    qk.tril_()
    d_state_decays1 = einsum('nsv,nst->ntv', d_out_att, qk)

    qk = einsum('nsk,ntk->nst', q_[:, c], k_[:, c])
    qk.tril_()
    d_state_decays1 = einsum('nsv,nst->ntv', d_out_att, qk)
    d_state_decays2 = einsum('nvk,ntk->ntv', -d_state, k_[:, c]) # state_add
    d_state_decays = d_state_decays1 + d_state_decays2
    d_w[:, c] = einsum('ntv,nvk->ntk', d_state_decays, state) # state_decays

    d_state_decays2 = einsum('nvk,ntk->ntv', -d_state, k_[:, c]) # state_add
    d_state_decays = d_state_decays1 + d_state_decays2
    d_w[:, c] = einsum('ntv,nvk->ntk', d_state_decays, state) # state_decays
    d_state2 = einsum('ntv,ntk->nvk', d_state_decays, w[:, c]) # state_decays
    d_state = d_state + d_state1 + d_state2

    d_state2 = einsum('ntv,ntk->nvk', d_state_decays, w[:, c]) # state_decays
    d_state = d_state + d_state1 + d_state2
    d_u[:, 0] = einsum('nvk,ntk->ntv', d_state, k_[:, 0]) # state_add
    d_k_[:, 0] = einsum('nvk,ntv->ntk', d_state, u[:, 0]) # state_add

    return d_q_.view(NH, T, D), d_k_.view(NH, T, D), d_w.view(NH, T, D), d_u.view(NH, T, D)
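
This restructuring also commits to the "uncompute" strategy: rather than storing one state per chunk for backpropagation through time, the backward stitch recovers state_{c-1} from state_c by subtracting the chunk's additive update. A toy sketch of the idea, exact only up to floating-point rounding (the function name is ours):

    import torch
    from torch import einsum

    def walk_states_backward(state, u, k):
        "Yield (c, state_{c-1}) for c = C-1..1, given the final state and chunked u, k of shape (N, C, t, D)."
        C = u.shape[1]
        for c in range(C - 1, 0, -1):
            state = state - einsum('ntv,ntk->nvk', u[:, c], k[:, c])  # undo state += u_c.T k_c
            yield c, state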

  4. proger revised this gist Jul 15, 2024. 1 changed file with 96 additions and 68 deletions.
    164 changes: 96 additions & 68 deletions deltanet.py
    Original file line number Diff line number Diff line change
    @@ -36,7 +36,31 @@
    import torch
    from torch import einsum, randn, allclose, stack, eye, manual_seed, no_grad, set_float32_matmul_precision, compile, arange

    set_float32_matmul_precision('high')
    #set_float32_matmul_precision('high')


    def tileprint(K, name='K'):
    "format matches tileprint in tk code so you can diff it"
    assert K.shape == (16, 16)
    for laneid in range(32):
    row_top = laneid // 4
    row_bottom = row_top + 8
    col_left = laneid % 4 * 2
    col_right = col_left + 8

    def fmt(r,c,tag):
    odd = "y" in tag
    if odd: # do not print r for odd rows because cuda printf silently runs out of function arguments
    return f"{name}[,{c:02}] {tag}={K[r,c]: .3f}"
    else:
    return f"{name}[{r:02},{c:02}] {tag}={K[r,c]: .3f}"

    print(f"lane={laneid:02}", " ".join([
    " ".join([fmt(row_top, col_left, "0x"), fmt(row_top, col_left+1, "0y")]),
    " ".join([fmt(row_bottom, col_left, "1x"), fmt(row_bottom, col_left+1, "1y")]),
    " ".join([fmt(row_top, col_right, "2x"), fmt(row_top, col_right+1, "2y")]),
    " ".join([fmt(row_bottom, col_right, "3x"), fmt(row_bottom, col_right+1, "3y")])
    ]))


    def decay_values(q, k, v, beta):
    @@ -53,9 +77,8 @@ def decay_values(q, k, v, beta):
    u[:, t] -= beta_[:, t] * einsum('nt,ntd->nd', K[:, :t, t], u[:, :t].clone())

    # attend to decayed values
    qk = einsum("nsi,nti->nst", q, k)
    mask = q.new_ones(T, T).tril()
    qk = einsum('nst,st->nst', qk, mask)
    qk = einsum("nsk,ntk->nst", q, k)
    qk.tril_()
    y = einsum("nst,ntj->nsj", qk, u)

    return w, u, y
    @@ -78,18 +101,29 @@ def forward_chunkwise(q, k, v, beta, chunk_size=2):
    u = u.view(NH, C, chunk_size, D)
    w = w.view(NH, C, chunk_size, D)

    # materialize the state for the leading chunk
    # materialize the state for the leading chunk
    state = einsum('ntv,ntk->nvk', u[:, 0], k_[:, 0])

    print(y, 'y')

    y_deltas = [u.new_zeros(NH, chunk_size, D)]
    for c in range(1, C):
    u_old = einsum('nvk,ntk->ntv', state, w[:, c])
    y_cur = einsum('nvk,nsk->nsv', state, q_[:, c])

    if c == 1:
    # early on this is cheaper than a state query
    kw = einsum('nsk,ntk->nst', k_[:, c-1], w[:, c]) # TDT
    u_old = einsum('nsv,nst->ntv', u[:, c-1], kw) # TDT

    kq = einsum('nsk,ntk->nst', k_[:, c-1], q_[:, c]) # TDT
    y_cur = einsum('nsv,nst->ntv', u[:, c-1], kq) # TTD
    else:
    u_old = einsum('nvk,ntk->ntv', state, w[:, c]) # DDT
    y_cur = einsum('nvk,nsk->nsv', state, q_[:, c]) # DDT

    # attend to old values
    qk = einsum("nsi,nti->nst", q_[:, c], k_[: ,c])
    qk = einsum("nsi,nti->nst", q_[:, c], k_[: ,c]) # TDT
    qk = qk.tril()
    y_prev = einsum("nst,ntj->nsj", qk, u_old)

    y_prev = einsum("nst,ntv->nsv", qk, u_old) # TTD

    y_deltas.append(y_cur - y_prev)

    @@ -98,7 +132,11 @@ def forward_chunkwise(q, k, v, beta, chunk_size=2):

    y_delta = torch.stack(y_deltas, dim=1)

    return y.view(NH, T, D) + y_delta.view(NH, T, D)
    w = w.view(NH, T, D)
    u = u.view(NH, T, D)
    y = y.view(NH, T, D) + y_delta.view(NH, T, D)

    return w, u, y


    def forward_loop(q, k, v, beta):
    @@ -134,19 +172,18 @@ def shape(q, k, v, beta=None):
    return NH, T, D


    def make_example(NH, T, D):
    def make_example(NH, T, D, device='cpu', dtype=torch.float32):
    manual_seed(0)
    q = randn(NH, T, D) / D**0.5
    q = randn(NH, T, D, device=device, dtype=dtype) / D**0.5
    q.requires_grad_()
    k = randn(NH, T, D) / D**0.5
    k = randn(NH, T, D, device=device, dtype=dtype) / D**0.5
    k.requires_grad_()
    v = randn(NH, T, D) / D**0.5
    v = randn(NH, T, D, device=device, dtype=dtype) / D**0.5
    v.requires_grad_()
    beta = randn(NH, T).sigmoid()
    beta = randn(NH, T, device=device, dtype=dtype).sigmoid()
    beta.requires_grad_()
    return q, k, v, beta

    #%%

    @no_grad()
    def decay_values_backward(d_out_w, d_out_u, d_out_y, q, k, v, beta):
    @@ -162,6 +199,7 @@ def decay_values_backward(d_out_w, d_out_u, d_out_y, q, k, v, beta):
    v = v.clone() # load v
    beta = beta.clone() # load beta
    d_out_w = d_out_w.clone() # ntk
    d_out_y = d_out_y.clone() # ntv

    w = k.new_zeros(NH, T, D) # ntk
    u = v.new_zeros(NH, T, D) # ntw
    @@ -218,7 +256,7 @@ def decay_values_backward(d_out_w, d_out_u, d_out_y, q, k, v, beta):

    v.zero_() # reuse register space of v for d_out_u
    d_out_u = d_out_u.clone() # load ntw
    d_out_u += einsum('nsv,nst->ntv', d_out_y, tt)
    d_out_u += einsum('nst,nsv->ntv', tt, d_out_y)

    #
    # backward for d_k, d_v, d_beta
    @@ -275,93 +313,84 @@ def decay_values_backward(d_out_w, d_out_u, d_out_y, q, k, v, beta):

    return d_q, d_k, d_v, d_beta

    # %%


    def stitch_backward(d_y_delta, q, k, w, u, C, chunk_size):
    NH, T, D = shape(q, k, None, None)

    d_y_delta = d_y_delta.view(NH, C, chunk_size, D)
    q_ = q.view(NH, C, chunk_size, D)
    k_ = k.view(NH, C, chunk_size, D)
    u = u.view(NH, C, chunk_size, D)
    u = u.view(NH, C, chunk_size, D).clone()
    w = w.view(NH, C, chunk_size, D)

    d_q_ = q.new_zeros(NH, C, chunk_size, D)
    d_k_ = k.new_zeros(NH, C, chunk_size, D)
    d_w = w.new_zeros(NH, C, chunk_size, D)
    d_u = u.new_zeros(NH, C, chunk_size, D)
    d_state = w.new_zeros(NH, D, D) # NHVK
    y_delta = u.new_zeros(NH, C, chunk_size, D) # leading chunk has zero delta

    # storing all states for BPTT
    states = k.new_zeros(NH, C, D, D) # NHCVK
    # materialize the state for the leading chunk
    states[:, 0] = einsum('ntv,ntk->nvk', u[:, 0], k_[:, 0])
    state = einsum('ntv,ntk->nvk', u[:, 0], k_[:, 0])

    # stitch forward
    for c in range(1, C):
    u_old = einsum('nvk,ntk->ntv', states[:, c-1], w[:, c])
    y_cur = einsum('nvk,nsk->nsv', states[:, c-1], q_[:, c])

    # attend to old values
    qk = einsum("nsi,nti->nst", q_[:, c], k_[: ,c])
    qk = qk.tril()
    y_prev = einsum("nst,ntj->nsj", qk, u_old)
    y_delta[:, c] = y_cur - y_prev
    state_decay = einsum('nvk,ntk->ntv', state, w[:, c])
    u[:, c] = u[:, c] - state_decay
    state_delta = einsum('ntv,ntk->nvk', u[:, c], k_[:, c])
    state = state + state_delta

    state_delta = einsum('ntv,ntk->nvk', u[:, c] - u_old, k_[:, c])
    states[:, c] = states[:, c-1] + state_delta
    # from now on, u's are decayed

    # stitch backward
    for c in range(C-1, -1, -1):
    d_u[:, c] = einsum('nvk,ntk->ntv', d_state, k_[:, c]) # state_add

    if c == 0:
    prev_state = torch.zeros_like(d_state)
    d_k_[:, c] = einsum('nvk,ntv->ntk', d_state, u[:, c]) # state_add
    else:
    prev_state = states[:, c-1]
    state_delta = einsum('ntv,ntk->nvk', u[:, c], k_[:, c])
    state = state - state_delta # uncompute the state
    state_decay = einsum('nvk,ntk->ntv', state, w[:, c])

    state_decays = einsum('nvk,ntk->ntv', prev_state, w[:, c])
    d_q1 = einsum('nsv,nvk->nsk', d_y_delta[:, c], state) # prev_output

    d_q1 = einsum('nsv,nvk->nsk', d_y_delta[:, c], prev_state) # prev_output
    d_state1 = einsum('nsv,nsk->nvk', d_y_delta[:, c], q_[:, c]) # prev_output
    # causal_attend_backward for delta
    d_out_att = -d_y_delta[:, c]
    d_out_state_decays = einsum('nsv,ntv->nst', d_out_att, state_decay)
    d_out_state_decays.tril_()
    d_q2 = einsum('nst,ntk->nsk', d_out_state_decays, k_[:, c])
    d_k1 = einsum('nst,nsk->ntk', d_out_state_decays, q_[:, c])
    d_q_[:, c] = d_q1 + d_q2

    # causal_attend_backward for delta
    mask = q.new_ones(T, T).tril()
    d_out_att = -d_y_delta[:, c]
    d_out_state_decays = einsum('nsv,ntv->nst', d_out_att, state_decays)
    d_out_state_decays.tril_()
    d_q2 = einsum('nst,ntk->nsk', d_out_state_decays, k_[:, c])
    d_k1 = einsum('nst,nsk->ntk', d_out_state_decays, q_[:, c])
    qk = einsum('nsk,ntk->nst', q_[:, c], k_[:, c])
    qk.tril_()
    d_state_decays1 = einsum('nsv,nst->ntv', d_out_att, qk)
    d_k2 = einsum('nvk,ntv->ntk', d_state, u[:, c]) # state_add
    d_k_[:, c] = d_k1 + d_k2

    d_k2 = einsum('nvk,ntv->ntk', d_state, u[:, c] - state_decays) # state_add
    d_u[:, c] = einsum('nvk,ntk->ntv', d_state, k_[:, c]) # state_add
    d_state_decays2 = einsum('nvk,ntk->ntv', -d_state, k_[:, c]) # state_add
    d_state1 = einsum('nsv,nsk->nvk', d_y_delta[:, c], q_[:, c]) # prev_output

    d_state_decays = d_state_decays1 + d_state_decays2
    d_w[:, c] = einsum('ntv,nvk->ntk', d_state_decays, prev_state) # state_decays
    d_state2 = einsum('ntv,ntk->nvk', d_state_decays, w[:, c]) # state_decays
    qk = einsum('nsk,ntk->nst', q_[:, c], k_[:, c])
    qk.tril_()
    d_state_decays1 = einsum('nsv,nst->ntv', d_out_att, qk)

    d_state = d_state + d_state1 + d_state2
    d_q_[:, c] = d_q1 + d_q2
    d_k_[:, c] = d_k1 + d_k2
    d_state_decays2 = einsum('nvk,ntk->ntv', -d_state, k_[:, c]) # state_add
    d_state_decays = d_state_decays1 + d_state_decays2
    d_w[:, c] = einsum('ntv,nvk->ntk', d_state_decays, state) # state_decays

    d_state2 = einsum('ntv,ntk->nvk', d_state_decays, w[:, c]) # state_decays
    d_state = d_state + d_state1 + d_state2

    return d_q_.view(NH, T, D), d_k_.view(NH, T, D), d_w.view(NH, T, D), d_u.view(NH, T, D)


    class DeltaChunkwise(torch.autograd.Function):
    @staticmethod
    def forward(ctx, q, k, v, beta, chunk_size):
    y = forward_chunkwise(q, k, v, beta, chunk_size)
    ctx.save_for_backward(q, k, v, beta)
    w, u, y = forward_chunkwise(q, k, v, beta, chunk_size)
    ctx.save_for_backward(q, k, v, beta, w, u)
    ctx.chunk_size = chunk_size
    return y

    @staticmethod
    def backward(ctx, d_y):
    q, k, v, beta = ctx.saved_tensors
    q, k, v, beta, w, u = ctx.saved_tensors
    NH, T, D = shape(q, k, v, beta)
    chunk_size = ctx.chunk_size
    C = T // chunk_size
    @@ -371,8 +400,6 @@ def backward(ctx, d_y):
    v.view(NH*C, chunk_size, D), beta.view(NH*C, chunk_size)
    )

    w, u, y = decay_values(q_, k_, v_, beta_)

    d_q_1, d_k_1, d_w, d_u = stitch_backward(d_y, q, k, w, u, C=C, chunk_size=chunk_size)

    d_w = d_w.view(NH*C, chunk_size, D)
    @@ -396,19 +423,20 @@ def backward(ctx, d_y):


    def test_delta_chunkwise_backward():
    NH, T, D = 2, 16, 2
    NH, T, D = 2, 16, 3
    q1, k1, v1, beta1 = make_example(NH, T, D)

    y0 = forward_loop(q1, k1, v1, beta1)

    y1 = forward_chunkwise(q1, k1, v1, beta1, chunk_size=2)
    w1, u1, y1 = forward_chunkwise(q1, k1, v1, beta1, chunk_size=2)
    (y1 - torch.ones_like(y1).detach()).pow(2).mean().backward()

    assert allclose(y0, y1, atol=1e-5), 'y1 is wrong'

    q, k, v, beta = make_example(NH, T, D)
    y = DeltaChunkwise.apply(q, k, v, beta, 2)
    (y - torch.ones_like(y).detach()).pow(2).mean().backward()

    assert allclose(y0, y1, atol=1e-5), 'y1 is wrong'
    assert allclose(y1, y, atol=1e-5), 'y is wrong'

    # print(beta1.grad - beta.grad, 'beta.grad diff')
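
The chunk-level stitch that this revision works out is worth reading in isolation: each chunk's precomputed output is corrected for the state carried in from earlier chunks, and the state is then advanced by the chunk's net write. A sketch extracted from the forward pass above (the function name is ours; q_c, k_c, w_c, u_c are one chunk's slices as produced by `decay_values`):

    import torch
    from torch import einsum

    def stitch_chunk(state, q_c, k_c, w_c, u_c):
        "state: (N, D, D); q_c, k_c, w_c, u_c: (N, t, D)."
        u_old = einsum('nvk,ntk->ntv', state, w_c)    # values this chunk forgets from the carried state
        y_cur = einsum('nvk,nsk->nsv', state, q_c)    # contribution of all earlier chunks
        qk = einsum('nsk,ntk->nst', q_c, k_c).tril()  # causal intra-chunk attention
        y_prev = einsum('nst,ntv->nsv', qk, u_old)    # part already decayed within the chunk
        y_delta = y_cur - y_prev                      # correction to the chunk-local output
        state = state + einsum('ntv,ntk->nvk', u_c - u_old, k_c)  # net write of this chunk
        return y_delta, state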
  5. proger revised this gist Jul 12, 2024. 1 changed file with 116 additions and 344 deletions.
    460 changes: 116 additions & 344 deletions deltanet.py
    @@ -2,13 +2,12 @@
    DeltaNet implementation reference for Accelerated Scan. DeltaNet performs efficient management of a large fixed-size memory.
    For a simple single chunk version see `forward_simple`.
    It computes decayed values by a little bit of recurrence (`decay_values`)
    and then applies linear attention (`causal_attend`).
    It computes decayed values by a little bit of recurrence and then applies linear attention (`decay_values`).
    `forward_chunkwise` is inspired by Yang 2024. It applies single chunk version pointwise and
    then performs chunk-level stitching.
    forward_ogloop and forward_scanloop are reference implementations of straightforward recurrences.
    forward_loop is the reference implementation of the original recurrence.
    References:
    [1] The WY Representation for Products of Householder Matrices (Bischof and Van Loan 1985)
    @@ -40,9 +39,9 @@
    set_float32_matmul_precision('high')


    def decay_values(k, v, beta):
    def decay_values(q, k, v, beta):
    "decay values applying deltanet forgetting rules"
    NH, T, D = shape(None, k, v, beta)
    NH, T, D = shape(q, k, v, beta)

    beta_ = beta.unsqueeze(-1)
    w = beta_ * k.clone()
    @@ -53,21 +52,13 @@ def decay_values(k, v, beta):
    w[:, t] -= beta_[:, t] * einsum('nt,ntd->nd', K[:, :t, t], w[:, :t].clone())
    u[:, t] -= beta_[:, t] * einsum('nt,ntd->nd', K[:, :t, t], u[:, :t].clone())

    return w, u


    def causal_attend(q, k, v, diagonal=0):
    "apply linear attention with a causal mask"
    NH, T, D = shape(q, k, v)
    mask = q.new_ones(T, T).tril(diagonal=diagonal)
    y = einsum("nsi,nti,st,ntj->nsj", q, k, mask, v)
    return y
    # attend to decayed values
    qk = einsum("nsi,nti->nst", q, k)
    mask = q.new_ones(T, T).tril()
    qk = einsum('nst,st->nst', qk, mask)
    y = einsum("nst,ntj->nsj", qk, u)


    def forward_simple(q, k, v, beta):
    "simple deltanet: linear attention to decayed values"
    w, u = decay_values(k, v, beta)
    return causal_attend(q, k, u)
    return w, u, y


    def forward_chunkwise(q, k, v, beta, chunk_size=2):
    @@ -79,18 +70,9 @@ def forward_chunkwise(q, k, v, beta, chunk_size=2):
    )

    # evaluate all chunks in parallel
    w, u = decay_values(k_, v_, beta_)
    y = causal_attend(q_, k_, u)
    w, u, y = decay_values(q_, k_, v_, beta_)

    # stitch chunks sequentially
    y_delta, _ = stitch_forward(q, k, w, u, C=C, chunk_size=chunk_size)
    return y.view(NH, T, D) + y_delta.view(NH, T, D)


    def stitch_forward(q, k, w, u, C, chunk_size):
    "stitch chunks sequentially"
    NH, T, D = shape(q, k, None, None)

    q_ = q.view(NH, C, chunk_size, D)
    k_ = k.view(NH, C, chunk_size, D)
    u = u.view(NH, C, chunk_size, D)
    @@ -99,28 +81,27 @@ def stitch_forward(q, k, w, u, C, chunk_size):
    # materialize the state for the leading chunk
    state = einsum('ntv,ntk->nvk', u[:, 0], k_[:, 0])

    deltas = [u.new_zeros(NH, chunk_size, D)]
    y_deltas = [u.new_zeros(NH, chunk_size, D)]
    for c in range(1, C):
    y_delta1, state = stitch1_forward(state, q_[:, c], k_[:, c], w[:, c], u[:, c])
    deltas.append(y_delta1)
    u_old = einsum('nvk,ntk->ntv', state, w[:, c])
    y_cur = einsum('nvk,nsk->nsv', state, q_[:, c])

    # attend to old values
    qk = einsum("nsi,nti->nst", q_[:, c], k_[: ,c])
    qk = qk.tril()
    y_prev = einsum("nst,ntj->nsj", qk, u_old)

    y_delta = torch.stack(deltas, dim=1)
    y_deltas.append(y_cur - y_prev)

    return y_delta, state
    state_delta = einsum('ntv,ntk->nvk', u[:, c] - u_old, k_[:, c])
    state = state + state_delta

    y_delta = torch.stack(y_deltas, dim=1)

    def stitch1_forward(state, q, k, w, u):
    state_decays = einsum('nvk,ntk->ntv', state, w)
    state_add = einsum('ntv,ntk->nvk', u - state_decays, k)

    delta = causal_attend(q, k, state_decays)
    prev_output = einsum('nvk,nsk->nsv', state, q)
    y = prev_output - delta

    return y, state + state_add
    return y.view(NH, T, D) + y_delta.view(NH, T, D)


    def forward_ogloop(q, k, v, beta):
    def forward_loop(q, k, v, beta):
    "reference: w_t = w_{t-1} + beta_t (v_t - w_t k_t) k_t"
    NH, T, D = shape(q, k, v, beta)

    @@ -142,30 +123,6 @@ def forward_ogloop(q, k, v, beta):
    return stack(y, dim=1)


    def forward_scanloop(q, k, v, beta):
    "reference via linear-time scan: w_t = w_{t-1} (I - beta_t k_t k_t.T) + beta v_t k_t.T"
    NH, T, D = shape(q, k, v, beta)

    w = k.new_zeros(NH, D, D)
    id = eye(D, device=w.device).expand(NH, D, D)
    y = []

    for t in range(T):
    q_ = q[:, t]
    k_ = k[:, t]
    v_ = v[:, t]
    beta_ = beta[:, t].unsqueeze(-1).unsqueeze(-1)
    beta_sqrt_ = beta_.squeeze(-1).sqrt()

    forget = id - einsum("ni,nj->nij", beta_sqrt_ * k_, beta_sqrt_ * k_)
    update = beta_ * einsum("ni,nj->nij", v_, k_)
    w = einsum("nik,nkj->nij", w, forget) + update

    y.append(einsum("nij,nj->ni", w, q_))

    return stack(y, dim=1)


    def shape(q, k, v, beta=None):
    NH, T, D = (q if q is not None else k).shape
    if q is not None:
    @@ -189,114 +146,18 @@ def make_example(NH, T, D):
    beta.requires_grad_()
    return q, k, v, beta


    def test_equal(atol=1e-6):
    NH, T, D = 2*3, 128, 16
    #NH, T, D = 1, 8, 3
    q, k, v, beta = make_example(NH, T, D)

    y1 = forward_ogloop(q, k, v, beta)
    y2 = forward_scanloop(q, k, v, beta)
    y3 = forward_simple(q, k, v, beta)

    assert allclose(y1, y2, atol=atol), (y1 - y2).abs().max()
    assert allclose(y1, y3, atol=atol), (y1 - y3).abs().max()

    for chunk_size in (1,2,4,8):
    y = forward_chunkwise(q, k, v, beta, chunk_size)
    assert allclose(y1, y, atol=atol), (y1 - y).abs().max()


    test_equal()

    #%%


    @no_grad()
    def attend_backward(d_out, q, k, v, g):
    d_q = einsum('nsv,ntk,ntv,nst->nsk', d_out, k, v, g)
    d_k = einsum('nsv,nsk,ntv,nst->ntk', d_out, q, v, g)
    d_v = einsum('nsv,nsk,ntk,nst->ntv', d_out, q, k, g)
    d_g = einsum('nsv,nsk,ntk,ntv,nst->nst', d_out, q, k, v, g)
    return d_q, d_k, d_v, d_g


    @no_grad()
    def causal_attend_backward(d_out, q, k, v, diagonal=0):
    NH, T, D = shape(q, k, v)
    mask = q.new_ones(T, T).tril(diagonal=diagonal).unsqueeze(0)
    d_q, q_k, d_v, _d_mask = attend_backward(d_out, q, k, v, mask)
    return d_q, q_k, d_v


    class CausalAttend(torch.autograd.Function):
    @staticmethod
    def forward(ctx, q, k, v):
    ctx.save_for_backward(q, k, v)
    return causal_attend(q, k, v)

    @staticmethod
    def backward(ctx, d_out):
    q, k, v = ctx.saved_tensors
    return causal_attend_backward(d_out, q, k, v)


    def test_equal_attend_backward(atol=1e-5):
    NH, T, D = 1*1, 512, 64
    q, k, v, beta = make_example(NH, T, D)

    y = causal_attend(q, k, v)
    d_q, d_k, d_v = causal_attend_backward(torch.ones_like(y), q, k, v)
    y.sum().backward()

    assert allclose(q.grad, d_q, atol=atol), 'q.grad is wrong'
    assert allclose(k.grad, d_k, atol=atol), 'k.grad is wrong'
    assert allclose(v.grad, d_v, atol=atol), 'v.grad is wrong'

    ## TODO: test gates (g)
    # print((g_hook.grad - d_g).pow(2).mean(), 'error')
    # print((g_hook.grad - d_g).abs().max(), 'max abs error')
    # assert torch.allclose(g_hook.grad, d_g, atol=1e-1), 'g.grad is wrong'


    def test_equal_attend_backward2(atol=1e-5):
    NH, T, D = 1, 2, 2
    q1, k1, v1, beta1 = make_example(NH, T, D)
    y1 = causal_attend(q1, k1, v1)
    (y1 - torch.ones_like(y1)).pow(2).mean().backward()

    q, k, v, beta = make_example(NH, T, D)
    y = CausalAttend.apply(q, k, v)
    (y - torch.ones_like(y)).pow(2).mean().backward()

    # print(q1.grad - q.grad, 'q.grad diff')
    # print(k1.grad - k.grad, 'k.grad diff')
    # print(v1.grad - v.grad, 'v.grad diff')
    # print(k.grad, 'k.grad')
    # print(k1.grad, 'k1.grad')
    # print(v.grad, 'v.grad')
    # print(v1.grad, 'v1.grad')

    assert (q1.grad - q.grad).abs().max() < atol, 'q.grad is wrong'
    assert (k1.grad - k.grad).abs().max() < atol, 'k.grad is wrong'
    assert (v1.grad - v.grad).abs().max() < atol, 'v.grad is wrong'


    test_equal_attend_backward()
    test_equal_attend_backward2()


    #%%

    @no_grad()
    def decay_values_backward(d_out_w, d_out_u, k, v, beta):
    NH, T, D = shape(None, k, v, beta)
    def decay_values_backward(d_out_w, d_out_u, d_out_y, q, k, v, beta):
    NH, T, D = shape(q, k, v, beta)

    #
    # allocations
    #

    # this group is loaded from global memory
    q = q.clone() # load q
    k = k.clone() # load k
    v = v.clone() # load v
    beta = beta.clone() # load beta
    @@ -341,137 +202,82 @@ def decay_values_backward(d_out_w, d_out_u, k, v, beta):
    u_bases = u_bases - v

    #
    # backward for d_k, d_v, d_beta
    # causal_attend_backward for d_q, d_k_2, d_out_u
    #

    tt = einsum('nsv,ntv->nst', d_out_y, u)
    tt = tt.tril()
    d_q = einsum('nst,ntk->nsk', tt, k)
    d_q.clone() # store

    d_k_2 = einsum('nst,nsk->ntk', tt, q)
    d_k_2.clone() # store to shared memory?

    tt = einsum('nsk,ntk->nst', q, k)
    tt = tt.tril()

    v.zero_() # reuse register space of v for d_out_u
    d_out_u = d_out_u.clone() # ntw
    d_out_u = d_out_u.clone() # load ntw
    d_out_u += einsum('nsv,nst->ntv', d_out_y, tt)

    #
    # backward for d_k, d_v, d_beta
    #

    d_k.zero_()

    for t in range(T-1,-1,-1):
    # d_k
    tt = einsum('njw,ntw->njt', w, d_out_w) # matmul for the sake of one column t
    tt[:, t:, :] = 0
    tk = einsum('njt,njk->ntk', tt, k)

    tt = einsum('njv,ntv->njt', u, d_out_u) # matmul for the sake of one column t
    tt[:, t:, :] = 0
    tk += einsum('njt,njk->ntk', tt, k)

    d_k[:, t] += tk[:, t]

    # backpropagate through time, updating only remaining timestamps
    tt.zero_()
    tt[:, t] += bKl[:, t]
    tk = einsum('ntj,ntk->njk', tt, d_out_w)
    d_out_w -= tk
    d_out_w = d_out_w - tk
    tk = einsum('ntj,ntk->njk', tt, d_out_u)
    d_out_u -= tk
    d_out_u = d_out_u - tk

    d_k = d_out_w - d_k
    d_k = einsum('nt,ntk->ntk', beta, d_k)
    d_k = einsum('ntk,nt->ntk', d_k, beta)

    # decay w and u
    tt.zero_() # reuse K again
    tt += einsum('njw,ntw->ntj', w, d_out_w)
    tt += einsum('njw,ntw->ntj', u, d_out_u)
    tt = einsum('ntw,njw->ntj', d_out_w, w)
    tt += einsum('ntw,njw->ntj', d_out_u, u)
    tt.tril_(diagonal=-1)

    tk = einsum('ntj,ntk->njk', tt, bk)
    d_k -= tk
    d_k = d_k - tk
    d_k_2 = d_k_2.clone() # load from shared memory
    d_k = d_k_2 + d_k
    d_k = d_k.clone() # store

    # d_beta
    w_bases *= d_out_w
    u_bases *= d_out_u
    w_bases = einsum('ntk,ntk->ntk', w_bases, d_out_w)
    u_bases = einsum('ntw,ntw->ntw', u_bases, d_out_u)

    # d_v using d_out_u register
    d_out_u = einsum('nt,ntv->ntv', beta, d_out_u)
    d_v = d_out_u.clone() # store

    # continue d_beta reusing the beta register
    beta = einsum('ntk->nt', w_bases)
    beta += einsum('ntk->nt', u_bases)
    beta += einsum('ntv->nt', u_bases)
    d_beta = beta.clone() # store

    return d_k, d_v, d_beta


    class DecayValues(torch.autograd.Function):
    @staticmethod
    def forward(ctx, k, v, beta):
    w, u = decay_values(k, v, beta)
    ctx.save_for_backward(k, v, beta)
    return w, u

    @staticmethod
    def backward(ctx, d_out_w, d_out_u):
    k, v, beta = ctx.saved_tensors
    return decay_values_backward(d_out_w, d_out_u, k, v, beta)


    def test_equal_decay_values_backward():
    NH, T, D = 1, 16, 16

    q, k, v, beta = make_example(NH, T, D)
    w, u = decay_values(k, v, beta)
    (w + u - torch.ones_like(w)).pow(2).mean().backward()
    #(w - torch.ones_like(w)).pow(2).mean().backward()

    q1, k1, v1, beta1 = make_example(NH, T, D)
    w1, u1 = DecayValues.apply(k1, v1, beta1)
    (w1 + u1 - torch.ones_like(w1)).pow(2).mean().backward()
    #(w1 - torch.ones_like(w1)).pow(2).mean().backward()

    # print(v.grad, 'v.grad', v.grad.shape)
    # print(v1.grad, 'v1.grad')
    # print(v.grad - v1.grad, 'v diff')
    assert allclose(v.grad, v1.grad, atol=1e-5), 'v1_grad is wrong'

    # print(beta.grad, 'beta.grad du')
    # print(beta1.grad, 'beta1.grad du')
    assert allclose(beta.grad, beta1.grad, atol=1e-5), 'beta1.grad is wrong'

    # print(k.grad, 'k.grad du')
    # print(k1.grad, 'k1.grad du')
    # print(k.grad - k1.grad, 'diff du')
    assert allclose(k.grad, k1.grad, atol=1e-5), 'k1_grad is wrong'

    test_equal_decay_values_backward()
    return d_q, d_k, d_v, d_beta

    # %%


    def stitch1_backward(d_y, d_state, state, q, k, w, u):
    state_decays = einsum('nvk,ntk->ntv', state, w)

    d_q1 = einsum('nsv,nvk->nsk', d_y, state) # prev_output
    d_state1 = einsum('nsv,nsk->nvk', d_y, q) # prev_output

    d_q2, d_k1, d_state_decays1 = causal_attend_backward(-d_y, q, k, state_decays) # delta

    d_k2 = einsum('nvk,ntv->ntk', d_state, u - state_decays) # state_add
    d_u = einsum('nvk,ntk->ntv', d_state, k) # state_add
    d_state_decays2 = einsum('nvk,ntk->ntv', -d_state, k) # state_add

    d_state_decays = d_state_decays1 + d_state_decays2
    d_w = einsum('ntv,nvk->ntk', d_state_decays, state) # state_decays
    d_state2 = einsum('ntv,ntk->nvk', d_state_decays, w) # state_decays

    return d_state + d_state1 + d_state2, d_q1 + d_q2, d_k1 + d_k2, d_w, d_u


    class Stitch1(torch.autograd.Function):
    @staticmethod
    def forward(ctx, state, q, k, w, u):
    y, new_state = stitch1_forward(state, q, k, w, u)
    ctx.save_for_backward(state, q, k, w, u)
    return y, new_state

    @staticmethod
    def backward(ctx, d_y, d_state):
    state, q, k, w, u = ctx.saved_tensors
    return stitch1_backward(d_y, d_state, state, q, k, w, u)


    def stitch_backward(d_y_delta, q, k, w, u, C, chunk_size):
    NH, T, D = shape(q, k, None, None)

    @@ -493,92 +299,58 @@ def stitch_backward(d_y_delta, q, k, w, u, C, chunk_size):
    # materialize the state for the leading chunk
    states[:, 0] = einsum('ntv,ntk->nvk', u[:, 0], k_[:, 0])

    # stitch forward
    for c in range(1, C):
    y_delta[:, c], states[:, c] = stitch1_forward(states[:, c-1], q_[:, c], k_[:, c], w[:, c], u[:, c])
    u_old = einsum('nvk,ntk->ntv', states[:, c-1], w[:, c])
    y_cur = einsum('nvk,nsk->nsv', states[:, c-1], q_[:, c])

    # attend to old values
    qk = einsum("nsi,nti->nst", q_[:, c], k_[: ,c])
    qk = qk.tril()
    y_prev = einsum("nst,ntj->nsj", qk, u_old)
    y_delta[:, c] = y_cur - y_prev

    state_delta = einsum('ntv,ntk->nvk', u[:, c] - u_old, k_[:, c])
    states[:, c] = states[:, c-1] + state_delta

    # stitch backward
    for c in range(C-1, -1, -1):
    if c == 0:
    prev_state = torch.zeros_like(d_state)
    else:
    prev_state = states[:, c-1]

    state_decays = einsum('nvk,ntk->ntv', prev_state, w[:, c])

    d_q1 = einsum('nsv,nvk->nsk', d_y_delta[:, c], prev_state) # prev_output
    d_state1 = einsum('nsv,nsk->nvk', d_y_delta[:, c], q_[:, c]) # prev_output

    # causal_attend_backward for delta
    mask = q.new_ones(T, T).tril()
    d_out_att = -d_y_delta[:, c]
    d_out_state_decays = einsum('nsv,ntv->nst', d_out_att, state_decays)
    d_out_state_decays.tril_()
    d_q2 = einsum('nst,ntk->nsk', d_out_state_decays, k_[:, c])
    d_k1 = einsum('nst,nsk->ntk', d_out_state_decays, q_[:, c])
    qk = einsum('nsk,ntk->nst', q_[:, c], k_[:, c])
    qk.tril_()
    d_state_decays1 = einsum('nsv,nst->ntv', d_out_att, qk)

    d_k2 = einsum('nvk,ntv->ntk', d_state, u[:, c] - state_decays) # state_add
    d_u[:, c] = einsum('nvk,ntk->ntv', d_state, k_[:, c]) # state_add
    d_state_decays2 = einsum('nvk,ntk->ntv', -d_state, k_[:, c]) # state_add

    d_state_decays = d_state_decays1 + d_state_decays2
    d_w[:, c] = einsum('ntv,nvk->ntk', d_state_decays, prev_state) # state_decays
    d_state2 = einsum('ntv,ntk->nvk', d_state_decays, w[:, c]) # state_decays

    d_state = d_state + d_state1 + d_state2
    d_q_[:, c] = d_q1 + d_q2
    d_k_[:, c] = d_k1 + d_k2

    for c in range(C-1, 0, -1):
    (
    d_state, d_q_[:, c], d_k_[:, c], d_w[:, c], d_u[:, c]
    ) = stitch1_backward(d_y_delta[:, c], d_state, states[:, c-1], q_[:, c], k_[:, c], w[:, c], u[:, c])

    (
    d_state, d_q_[:, 0], d_k_[:, 0], d_w[:, 0], d_u[:, 0]
    ) = stitch1_backward(d_y_delta[:, 0], d_state, torch.zeros_like(d_state), q_[:, 0], k_[:, 0], w[:, 0], u[:, 0])

    return d_q_.view(NH, T, D), d_k_.view(NH, T, D), d_w.view(NH, T, D), d_u.view(NH, T, D)


    class Stitch(torch.autograd.Function):
    @staticmethod
    def forward(ctx, q, k, w, u, C, chunk_size):
    y_delta, state = stitch_forward(q, k, w, u, C, chunk_size)
    ctx.save_for_backward(q, k, w, u)
    ctx.C = C
    ctx.chunk_size = chunk_size
    return y_delta

    @staticmethod
    def backward(ctx, d_y_delta):
    q, k, w, u = ctx.saved_tensors
    return *stitch_backward(d_y_delta, q, k, w, u, ctx.C, ctx.chunk_size), None, None


    def test_stitch_all(atol=1e-5):
    NH, T, D = 1, 4, 2
    C, chunk_size = 2, 2
    q, k, v, beta = make_example(NH, T, D)
    w, u = decay_values(k, v, beta)

    q.requires_grad_()
    k.requires_grad_()
    v.requires_grad_()
    beta.requires_grad_()
    w.retain_grad()
    u.retain_grad()

    y, new_state = stitch_forward(q, k, w, u, C=C, chunk_size=chunk_size)
    loss = (y - torch.ones_like(y)).pow(2).mean()
    loss.backward()

    # print(q.grad, 'q.grad')
    # print(k.grad, 'k.grad')
    # print(v.grad, 'v.grad')
    # print(beta.grad, 'beta.grad')
    # print(w.grad, 'w.grad')
    # print(u.grad, 'u.grad')

    q1, k1, v1, beta1 = make_example(NH, T, D)
    w1, u1 = decay_values(k1, v1, beta1)

    q1.requires_grad_()
    k1.requires_grad_()
    v1.requires_grad_()
    beta1.requires_grad_()
    w1.retain_grad()
    u1.retain_grad()

    y1 = Stitch.apply(q1, k1, w1, u1, C, chunk_size)
    loss = (y1 - torch.ones_like(y1)).pow(2).mean()
    loss.backward()

    assert allclose(y, y1, atol=atol), 'y is wrong'

    assert allclose(u.grad, u1.grad, atol=atol), 'u.grad is wrong'
    assert allclose(v.grad, v1.grad, atol=atol), 'v.grad is wrong'
    # print(k.grad, 'k.grad')
    # print(k1.grad, 'k1.grad')
    assert allclose(k.grad, k1.grad, atol=atol), 'k.grad is wrong'
    assert allclose(q.grad, q1.grad, atol=atol), 'q.grad is wrong'
    assert allclose(beta.grad, beta1.grad, atol=atol), 'beta.grad is wrong'
    assert allclose(w.grad, w1.grad, atol=atol), 'w.grad is wrong'


    test_stitch_all()


    #%%


    class DeltaChunkwise(torch.autograd.Function):
    @staticmethod
    def forward(ctx, q, k, v, beta, chunk_size):
    @@ -599,27 +371,24 @@ def backward(ctx, d_y):
    v.view(NH*C, chunk_size, D), beta.view(NH*C, chunk_size)
    )

    w, u = decay_values(k_, v_, beta_)
    w, u, y = decay_values(q_, k_, v_, beta_)

    d_q_1, d_k_1, d_w, d_u1 = stitch_backward(d_y, q, k, w, u, C=C, chunk_size=chunk_size)
    d_q_1, d_k_1, d_w, d_u = stitch_backward(d_y, q, k, w, u, C=C, chunk_size=chunk_size)

    d_w = d_w.view(NH*C, chunk_size, D)
    d_u1 = d_u1.view(NH*C, chunk_size, D)
    d_u = d_u.view(NH*C, chunk_size, D)
    d_y = d_y.view(NH*C, chunk_size, D)
    u = u.view(NH*C, chunk_size, D)

    d_q_2, d_k_2, d_u2 = causal_attend_backward(d_y, q_, k_, u)
    d_u = d_u1 + d_u2
    d_k_3, d_v_, d_beta_ = decay_values_backward(d_w, d_u, k_, v_, beta_)
    d_q_2, d_k_2, d_v_, d_beta_ = decay_values_backward(d_w, d_u, d_y, q_, k_, v_, beta_)

    d_q_1 = d_q_1.view(NH, T, D)
    d_q_2 = d_q_2.view(NH, C, chunk_size, D)
    d_q_2 = d_q_2.reshape(NH, T, D)
    d_q = d_q_1 + d_q_2

    d_k_2 = d_k_2.reshape(NH, T, D)
    d_k_3 = d_k_3.reshape(NH, T, D)
    d_k = d_k_1 + d_k_2 + d_k_3
    d_k = d_k_1 + d_k_2
    d_v = d_v_.reshape(NH, T, D)
    d_beta = d_beta_.view(NH, T)

    @@ -629,6 +398,8 @@ def backward(ctx, d_y):
    def test_delta_chunkwise_backward():
    NH, T, D = 2, 16, 2
    q1, k1, v1, beta1 = make_example(NH, T, D)

    y0 = forward_loop(q1, k1, v1, beta1)

    y1 = forward_chunkwise(q1, k1, v1, beta1, chunk_size=2)
    (y1 - torch.ones_like(y1).detach()).pow(2).mean().backward()
    @@ -637,6 +408,7 @@ def test_delta_chunkwise_backward():
    y = DeltaChunkwise.apply(q, k, v, beta, 2)
    (y - torch.ones_like(y).detach()).pow(2).mean().backward()

    assert allclose(y0, y1, atol=1e-5), 'y1 is wrong'
    assert allclose(y1, y, atol=1e-5), 'y is wrong'

    # print(beta1.grad - beta.grad, 'beta.grad diff')
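
The `decay_values` recurrence is the WY trick of reference [1]: the vectors w_t compress the product of per-step forget matrices into a single rank-T correction, prod_t (I - beta_t k_t k_t.T) = I - sum_t w_t k_t.T. A quick numerical check of that identity (standalone; variable names are ours):

    import torch
    from torch import einsum

    T, D = 8, 4
    k = torch.randn(T, D) / D**0.5
    beta = torch.rand(T)

    # WY recurrence, as in decay_values: w_t = beta_t (k_t - sum_{s<t} (k_s . k_t) w_s)
    w = torch.zeros(T, D)
    for t in range(T):
        w[t] = beta[t] * (k[t] - einsum('s,sd->d', k[:t] @ k[t], w[:t]))

    # product of forget matrices, applied left to right
    P = torch.eye(D)
    for t in range(T):
        P = P @ (torch.eye(D) - beta[t] * torch.outer(k[t], k[t]))

    assert torch.allclose(P, torch.eye(D) - einsum('td,te->de', w, k), atol=1e-5)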
  6. proger revised this gist Jul 11, 2024. 1 changed file with 32 additions and 30 deletions.
    62 changes: 32 additions & 30 deletions deltanet.py
    @@ -301,53 +301,53 @@ def decay_values_backward(d_out_w, d_out_u, k, v, beta):
    v = v.clone() # load v
    beta = beta.clone() # load beta
    d_out_w = d_out_w.clone() # ntk
    d_out_u = d_out_u.clone() # ntw

    w = k.new_zeros(NH, T, D) # ntk
    u = v.new_zeros(NH, T, D) # ntw
    w_bases = w.clone() # ntk
    u_bases = u.clone() # ntw

    bk = einsum('nt,ntk->ntk', beta, k)
    bv = einsum('nt,ntw->ntw', beta, v)

    K = k.new_zeros(NH, T, T) # reused later twice in the backward pass
    bKl = k.new_zeros(NH, T, T)
    tt = k.new_zeros(NH, T, T)

    d_k = k.new_zeros(NH, T, D) # nsk

    tk = k.new_zeros(NH, T, D) # ntk

    tvec_reg = k.new_zeros(NH, T) # nt
    kvec_reg = k.new_zeros(NH, D) # nw

    #
    # forward
    #

    K = einsum('ntd,nsd->nts', k, k)
    K = K.tril(diagonal=-1) # make_causal(0); set_diagonal(0)
    bKl = einsum('nt,nts->nts', beta, K) # multiply each row of K by beta
    tt = einsum('ntk,nsk->nts', k, k)
    tt = tt.tril(diagonal=-1) # make_causal(0); set_diagonal(0)
    bKl = einsum('nt,nts->nts', beta, tt) # multiply each row of K by beta

    u_bases = v
    v = einsum('nt,ntw->ntw', beta, v)

    for t in range(T):
    c_w = einsum('nts,nsk->ntk', bKl, w) # matmul for the sake of one row
    w[:, t] = bk[:, t, :] - c_w[:, t, :]
    c_u = einsum('nts,nsw->ntw', bKl, u) # matmul for the sake of one row
    u[:, t] = bv[:, t, :] - c_u[:, t, :]
    tk = einsum('nts,nsk->ntk', bKl, w) # matmul for the sake of one row
    w[:, t] = bk[:, t, :] - tk[:, t, :]
    tk = einsum('nts,nsw->ntw', bKl, u) # matmul for the sake of one row
    u[:, t] = v[:, t, :] - tk[:, t, :]

    w.clone() # store w
    u.clone() # store u

    w_bases = einsum('nts,nsk->ntk', K, w)
    w_bases = einsum('nts,nsk->ntk', tt, w)
    w_bases = k - w_bases
    u_bases = einsum('nts,nsw->ntw', K, u)
    u_bases = v - u_bases
    v = einsum('nts,nsw->ntw', tt, u)
    u_bases = u_bases - v

    #
    # backward for d_k, d_v, d_beta
    #

    v.zero_() # reuse register space of v for d_out_u
    d_out_u = d_out_u.clone() # ntw
    d_k.zero_()

    for t in range(T-1,-1,-1):
    # d_k
    tt = einsum('njw,ntw->njt', w, d_out_w) # matmul for the sake of one column t
    @@ -359,37 +359,39 @@ def decay_values_backward(d_out_w, d_out_u, k, v, beta):
    d_k[:, t] += tk[:, t]

    # backpropagate through time, updating only remaining timestamps
    K.zero_() # reuse K
    K[:, t] += bKl[:, t]
    tk = einsum('ntj,ntk->njk', K, d_out_w)
    tt.zero_()
    tt[:, t] += bKl[:, t]
    tk = einsum('ntj,ntk->njk', tt, d_out_w)
    d_out_w -= tk
    tk = einsum('ntj,ntk->njk', K, d_out_u)
    tk = einsum('ntj,ntk->njk', tt, d_out_u)
    d_out_u -= tk

    d_k = d_out_w - d_k
    d_k = einsum('nt,ntk->ntk', beta, d_k)

    # decay w and u
    K.zero_() # reuse K again
    K += einsum('njw,ntw->ntj', w, d_out_w)
    K += einsum('njw,ntw->ntj', u, d_out_u)
    K.tril_(diagonal=-1)
    tt.zero_() # reuse K again
    tt += einsum('njw,ntw->ntj', w, d_out_w)
    tt += einsum('njw,ntw->ntj', u, d_out_u)
    tt.tril_(diagonal=-1)

    tk = einsum('ntj,ntk->njk', K, bk)
    tk = einsum('ntj,ntk->njk', tt, bk)
    d_k -= tk
    d_k = d_k.clone() # store

    # d_beta
    w_bases *= d_out_w
    tvec_reg = einsum('ntk->nt', w_bases)
    u_bases *= d_out_u
    tvec_reg += einsum('ntk->nt', u_bases)
    d_beta = tvec_reg.clone() # store

    # d_v
    # d_v using d_out_u register
    d_out_u = einsum('nt,ntv->ntv', beta, d_out_u)
    d_v = d_out_u.clone() # store

    # continue d_beta reusing the beta register
    beta = einsum('ntk->nt', w_bases)
    beta += einsum('ntk->nt', u_bases)
    d_beta = beta.clone() # store

    return d_k, d_v, d_beta


  7. proger revised this gist Jul 11, 2024. 1 changed file with 44 additions and 54 deletions.
    98 changes: 44 additions & 54 deletions deltanet.py
    @@ -296,9 +296,12 @@ def decay_values_backward(d_out_w, d_out_u, k, v, beta):
    # allocations
    #

    # this group is loaded from global memory
    k = k.clone() # load k
    v = v.clone() # load v
    beta = beta.clone() # load beta
    d_out_w = d_out_w.clone() # ntk
    d_out_u = d_out_u.clone() # ntw

    w = k.new_zeros(NH, T, D) # ntk
    u = v.new_zeros(NH, T, D) # ntw
    @@ -308,93 +311,80 @@ def decay_values_backward(d_out_w, d_out_u, k, v, beta):
    bk = einsum('nt,ntk->ntk', beta, k)
    bv = einsum('nt,ntw->ntw', beta, v)

    K = k.new_zeros(NH, T, T) # reused later as t_regs_uw
    K = k.new_zeros(NH, T, T) # reused later twice in the backward pass
    bKl = k.new_zeros(NH, T, T)
    tt = k.new_zeros(NH, T, T)

    d_k = k.new_zeros(NH, T, D) # nsk

    d_out_w = d_out_w.clone() # ntk
    d_out_u = d_out_u.clone() # ntw

    tk = k.new_zeros(NH, T, D) # ntk

    t_reg = k.new_zeros(NH, T) # nt
    w_reg = k.new_zeros(NH, D) # nw
    tvec_reg = k.new_zeros(NH, T) # nt
    kvec_reg = k.new_zeros(NH, D) # nw

    #
    # forward
    #

    K = einsum('ntd,nsd->nts', k, k)
    K = K.tril(diagonal=-1) # make_causal(0); set_diagonal(0)
    bKl = einsum('nt,nts->nts', -beta, K) # multiply each row of K by beta
    bKl = einsum('nt,nts->nts', beta, K) # multiply each row of K by beta

    for t in range(T):
    c_w = einsum('nts,nsk->ntk', bKl, w)
    w[:, t] = bk[:, t, :] + c_w[:, t, :]
    c_u = einsum('nts,nsw->ntw', bKl, u)
    u[:, t] = bv[:, t, :] + c_u[:, t, :]
    c_w = einsum('nts,nsk->ntk', bKl, w) # matmul for the sake of one row
    w[:, t] = bk[:, t, :] - c_w[:, t, :]
    c_u = einsum('nts,nsw->ntw', bKl, u) # matmul for the sake of one row
    u[:, t] = bv[:, t, :] - c_u[:, t, :]

    w.clone() # store w
    u.clone() # store u

    w_bases = einsum('nts,nsk->ntk', K, w)
    w_bases = k - w_bases
    u_bases = einsum('nts,nsw->ntw', K, u)
    u_bases = v - u_bases

    w.clone() # store w
    u.clone() # store u

    #
    # backward for d_k, d_v, d_beta
    #

    t_regs_uw = K
    t_regs_uw.zero_()

    for t in range(T-1,-1,-1):
    # d_k
    w[:, t, :] = 0
    k[:, t, :] = 0
    u[:, t, :] = 0

    # wk
    t_reg = einsum('nw,njw->nj', d_out_w[:, t], w)
    w_reg = einsum('nj,njk->nk', t_reg, k)
    w_reg = d_out_w[:, t] - w_reg
    w_reg = einsum('n,nk->nk', beta[:,t], w_reg)
    d_k[:, t] += w_reg

    # uk
    t_reg = einsum('nw,njw->nj', d_out_u[:, t], u)
    w_reg = einsum('nj,njk->nk', t_reg, k)
    w_reg = einsum('n,nk->nk', beta[:,t], w_reg)
    d_k[:, t] -= w_reg

    # decay w
    w_reg = d_out_w[:, t]
    tk = einsum('nw,ntw->ntw', w_reg, w)
    t_reg = einsum('ntw->nt', tk)
    t_regs_uw[:, t] += t_reg

    # decay u
    w_reg = d_out_u[:, t]
    tk = einsum('nw,ntw->ntw', w_reg, u)
    t_reg = einsum('ntw->nt', tk)
    t_regs_uw[:, t] += t_reg

    # backpropagate through time
    d_out_w += einsum('nj,nk->njk', bKl[:, t, :], d_out_w[:, t])
    d_out_u += einsum('nj,nk->njk', bKl[:, t, :], d_out_u[:, t])

    tk = einsum('nst,nsk->ntk', t_regs_uw, bk)
    tt = einsum('njw,ntw->njt', w, d_out_w) # matmul for the sake of one column t
    tt[:, t:, :] = 0
    tk = einsum('njt,njk->ntk', tt, k)
    tt = einsum('njv,ntv->njt', u, d_out_u) # matmul for the sake of one column t
    tt[:, t:, :] = 0
    tk += einsum('njt,njk->ntk', tt, k)
    d_k[:, t] += tk[:, t]

    # backpropagate through time, updating only remaining timestamps
    K.zero_() # reuse K
    K[:, t] += bKl[:, t]
    tk = einsum('ntj,ntk->njk', K, d_out_w)
    d_out_w -= tk
    tk = einsum('ntj,ntk->njk', K, d_out_u)
    d_out_u -= tk

    d_k = d_out_w - d_k
    d_k = einsum('nt,ntk->ntk', beta, d_k)

    # decay w and u
    K.zero_() # reuse K again
    K += einsum('njw,ntw->ntj', w, d_out_w)
    K += einsum('njw,ntw->ntj', u, d_out_u)
    K.tril_(diagonal=-1)

    tk = einsum('ntj,ntk->njk', K, bk)
    d_k -= tk
    d_k = d_k.clone() # store

    # d_beta
    w_bases *= d_out_w
    t_reg = einsum('ntk->nt', w_bases)
    tvec_reg = einsum('ntk->nt', w_bases)
    u_bases *= d_out_u
    t_reg += einsum('ntk->nt', u_bases)
    d_beta = t_reg.clone() # store
    tvec_reg += einsum('ntk->nt', u_bases)
    d_beta = tvec_reg.clone() # store

    # d_v
    d_out_u = einsum('nt,ntv->ntv', beta, d_out_u)
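
    (The two "matmul for the sake of one row" products above compute a full (T, T) x (T, D)
    matmul of which only row t is consumed. A row-only equivalent, written out here just for
    clarity since a tile-based kernel prefers the full-shape product:

        w[:, t] = bk[:, t] - einsum('ns,nsk->nk', bKl[:, t, :t], w[:, :t])
        u[:, t] = bv[:, t] - einsum('ns,nsw->nw', bKl[:, t, :t], u[:, :t])
    )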
  8. proger revised this gist Jul 10, 2024. 1 changed file with 8 additions and 5 deletions.
    13 changes: 8 additions & 5 deletions deltanet.py
    @@ -308,7 +308,7 @@ def decay_values_backward(d_out_w, d_out_u, k, v, beta):
    bk = einsum('nt,ntk->ntk', beta, k)
    bv = einsum('nt,ntw->ntw', beta, v)

    K = k.new_zeros(NH, T, T)
    K = k.new_zeros(NH, T, T) # reused later as t_regs_uw
    bKl = k.new_zeros(NH, T, T)

    d_k = k.new_zeros(NH, T, D) # nsk
    @@ -347,6 +347,9 @@ def decay_values_backward(d_out_w, d_out_u, k, v, beta):
    # backward for d_k, d_v, d_beta
    #

    t_regs_uw = K
    t_regs_uw.zero_()

    for t in range(T-1,-1,-1):
    # d_k
    w[:, t, :] = 0
    @@ -370,20 +373,20 @@ def decay_values_backward(d_out_w, d_out_u, k, v, beta):
    w_reg = d_out_w[:, t]
    tk = einsum('nw,ntw->ntw', w_reg, w)
    t_reg = einsum('ntw->nt', tk)
    w_reg = bk[:, t]
    d_k -= einsum('nt,nk->ntk', t_reg, w_reg)
    t_regs_uw[:, t] += t_reg

    # decay u
    w_reg = d_out_u[:, t]
    tk = einsum('nw,ntw->ntw', w_reg, u)
    t_reg = einsum('ntw->nt', tk)
    w_reg = bk[:, t]
    d_k -= einsum('nt,nk->ntk', t_reg, w_reg)
    t_regs_uw[:, t] += t_reg

    # backpropagate through time
    d_out_w += einsum('nj,nk->njk', bKl[:, t, :], d_out_w[:, t])
    d_out_u += einsum('nj,nk->njk', bKl[:, t, :], d_out_u[:, t])

    tk = einsum('nst,nsk->ntk', t_regs_uw, bk)
    d_k -= tk
    d_k = d_k.clone() # store

    # d_beta
  9. proger revised this gist Jul 10, 2024. 1 changed file with 64 additions and 36 deletions.
    100 changes: 64 additions & 36 deletions deltanet.py
    @@ -296,78 +296,106 @@ def decay_values_backward(d_out_w, d_out_u, k, v, beta):
    # allocations
    #

    eye = torch.eye(D, device=k.device, dtype=k.dtype)
    eye = eye.unsqueeze(0).expand(NH, D, D)
    k = k.clone() # load k
    v = v.clone() # load v
    beta = beta.clone() # load beta

    # recompute w and u TK-style
    w = k.new_zeros(NH, T, D) # ntk
    u = v.new_zeros(NH, T, D) # ntw
    w_bases = w.clone()
    u_bases = u.clone()
    w_bases = w.clone() # ntk
    u_bases = u.clone() # ntw

    bk = einsum('nt,ntk->ntk', beta, k)
    bv = einsum('nt,ntw->ntw', beta, v)

    K = einsum('ntd,nsd->nts', k, k) # (T,T) matrix
    K = K.tril(diagonal=-1) # make_causal(0); set_diagonal(0)
    bKl = einsum('nt,nts->nts', -beta, K) # multiply each row of K by beta
    K = k.new_zeros(NH, T, T)
    bKl = k.new_zeros(NH, T, T)

    d_k = k.new_zeros(NH, T, D) # nsk
    d_beta = beta.new_zeros(NH, T) # ns
    d_v = v.new_zeros(NH, T, D) # nsv
    d_out_w_backward = d_out_w #.clone() # ntk # doesn't seem to need a copy but why?
    d_out_u_backward = d_out_u.clone() # ntw

    d_out_w = d_out_w.clone() # ntk
    d_out_u = d_out_u.clone() # ntw

    tk = k.new_zeros(NH, T, D) # ntk

    t_reg = k.new_zeros(NH, T) # nt
    w_reg = k.new_zeros(NH, D) # nw

    #
    # forward
    #

    K = einsum('ntd,nsd->nts', k, k)
    K = K.tril(diagonal=-1) # make_causal(0); set_diagonal(0)
    bKl = einsum('nt,nts->nts', -beta, K) # multiply each row of K by beta

    for t in range(T):
    c_w = einsum('nts,nsk->ntk', bKl, w)
    w[:, t] = bk[:, t, :] + c_w[:, t, :]
    c_u = einsum('nts,nsw->ntw', bKl, u)
    u[:, t] = bv[:, t, :] + c_u[:, t, :]

    w_bases = k - einsum('nts,nsk->ntk', K, w)
    u_bases = v - einsum('nts,nsw->ntw', K, u)
    w_bases = einsum('nts,nsk->ntk', K, w)
    w_bases = k - w_bases
    u_bases = einsum('nts,nsw->ntw', K, u)
    u_bases = v - u_bases

    w0 = w.clone() # we will be mutating these, so store original w and u here
    u0 = u.clone()
    w.clone() # store w
    u.clone() # store u

    #
    # backward for d_k, d_v, d_beta
    #

    for t in range(T-1,-1,-1):
    # d_k
    w[:, t, :] = 0
    k[:, t, :] = 0
    u[:, t, :] = 0
    wk = einsum('njw,njk->nwk', w, k)
    wk = eye - wk
    wk = einsum('n,nwk->nwk', beta[:, t], wk)
    uk = einsum('njw,njk->nwk', u, k)
    uk = einsum('n,nwk->nwk', beta[:, t], uk)

    # d_k
    d_k[:, t] += einsum('nw,nwk->nk', d_out_w_backward[:, t], wk)
    d_k[:, t] -= einsum('nw,nwk->nk', d_out_u_backward[:, t], uk)

    decay_w = einsum('nw,nsw->ns', d_out_w_backward[:, t], w[:, :t])
    decay_u = einsum('nw,nsw->ns', d_out_u_backward[:, t], u[:, :t])

    d_k[:, :t] -= einsum('nk,ns->nsk', bk[:, t], decay_w)
    d_k[:, :t] -= einsum('nk,ns->nsk', bk[:, t], decay_u)
    # wk
    t_reg = einsum('nw,njw->nj', d_out_w[:, t], w)
    w_reg = einsum('nj,njk->nk', t_reg, k)
    w_reg = d_out_w[:, t] - w_reg
    w_reg = einsum('n,nk->nk', beta[:,t], w_reg)
    d_k[:, t] += w_reg

    # uk
    t_reg = einsum('nw,njw->nj', d_out_u[:, t], u)
    w_reg = einsum('nj,njk->nk', t_reg, k)
    w_reg = einsum('n,nk->nk', beta[:,t], w_reg)
    d_k[:, t] -= w_reg

    # decay w
    w_reg = d_out_w[:, t]
    tk = einsum('nw,ntw->ntw', w_reg, w)
    t_reg = einsum('ntw->nt', tk)
    w_reg = bk[:, t]
    d_k -= einsum('nt,nk->ntk', t_reg, w_reg)

    # decay u
    w_reg = d_out_u[:, t]
    tk = einsum('nw,ntw->ntw', w_reg, u)
    t_reg = einsum('ntw->nt', tk)
    w_reg = bk[:, t]
    d_k -= einsum('nt,nk->ntk', t_reg, w_reg)

    # backpropagate through time
    d_out_w_backward[:, :t] += einsum('nj,nk->njk', bKl[:, t, :t], d_out_w_backward[:, t])
    d_out_u_backward[:, :t] += einsum('nj,nk->njk', bKl[:, t, :t], d_out_u_backward[:, t])
    d_out_w += einsum('nj,nk->njk', bKl[:, t, :], d_out_w[:, t])
    d_out_u += einsum('nj,nk->njk', bKl[:, t, :], d_out_u[:, t])

    d_k = d_k.clone() # store

    # d_beta
    d_beta += einsum('ntk,ntk->nt', w_bases, d_out_w_backward)
    d_beta += einsum('ntk,ntk->nt', u_bases, d_out_u_backward)
    w_bases *= d_out_w
    t_reg = einsum('ntk->nt', w_bases)
    u_bases *= d_out_u
    t_reg += einsum('ntk->nt', u_bases)
    d_beta = t_reg.clone() # store

    # d_v
    d_v = einsum('nt,ntv->ntv', beta, d_out_u_backward)
    d_out_u = einsum('nt,ntv->ntv', beta, d_out_u)
    d_v = d_out_u.clone() # store

    return d_k, d_v, d_beta

    @@ -386,7 +414,7 @@ def backward(ctx, d_out_w, d_out_u):


    def test_equal_decay_values_backward():
    NH, T, D = 1, 16, 3
    NH, T, D = 1, 16, 16

    q, k, v, beta = make_example(NH, T, D)
    w, u = decay_values(k, v, beta)
  10. proger revised this gist Jul 10, 2024. 1 changed file with 23 additions and 11 deletions.
    34 changes: 23 additions & 11 deletions deltanet.py
    @@ -292,9 +292,18 @@ def test_equal_attend_backward2(atol=1e-5):
    def decay_values_backward(d_out_w, d_out_u, k, v, beta):
    NH, T, D = shape(None, k, v, beta)

    #
    # allocations
    #

    eye = torch.eye(D, device=k.device, dtype=k.dtype)
    eye = eye.unsqueeze(0).expand(NH, D, D)

    # recompute w and u TK-style
    w = k.new_zeros(NH, T, D) # ntk
    u = v.new_zeros(NH, T, D) # ntw
    w_bases = w.clone()
    u_bases = u.clone()

    bk = einsum('nt,ntk->ntk', beta, k)
    bv = einsum('nt,ntw->ntw', beta, v)
    @@ -303,28 +312,31 @@ def decay_values_backward(d_out_w, d_out_u, k, v, beta):
    K = K.tril(diagonal=-1) # make_causal(0); set_diagonal(0)
    bKl = einsum('nt,nts->nts', -beta, K) # multiply each row of K by beta

    d_k = k.new_zeros(NH, T, D) # nsk
    d_beta = beta.new_zeros(NH, T) # ns
    d_v = v.new_zeros(NH, T, D) # nsv
    d_out_w_backward = d_out_w #.clone() # ntk # doesn't seem to need a copy but why?
    d_out_u_backward = d_out_u.clone() # ntw

    #
    # forward
    #

    for t in range(T):
    c_w = einsum('nts,nsk->ntk', bKl, w)
    w[:, t] = bk[:, t, :] + c_w[:, t, :]
    c_u = einsum('nts,nsw->ntw', bKl, u)
    u[:, t] = bv[:, t, :] + c_u[:, t, :]

    # compute gradients for d_k, d_v, d_beta
    d_k = k.new_zeros(NH, T, D) # nsk
    d_beta = beta.new_zeros(NH, T) # ns
    d_v = v.new_zeros(NH, T, D) # nsv

    eye = torch.eye(D, device=k.device, dtype=k.dtype)
    eye = eye.unsqueeze(0).expand(NH, D, D)

    w_bases = k - einsum('nts,nsk->ntk', K, w)
    u_bases = v - einsum('nts,nsw->ntw', K, u)

    w0 = w.clone() # we will be mutating these, but the kernel also returns the original w and u
    w0 = w.clone() # we will be mutating these, so store original w and u here
    u0 = u.clone()

    d_out_w_backward = d_out_w.clone() # ntk
    d_out_u_backward = d_out_u.clone() # ntw
    #
    # backward for d_k, d_v, d_beta
    #

    for t in range(T-1,-1,-1):
    w[:, t, :] = 0
  11. proger revised this gist Jul 10, 2024. 1 changed file with 7 additions and 7 deletions.
    14 changes: 7 additions & 7 deletions deltanet.py
    @@ -346,17 +346,17 @@ def decay_values_backward(d_out_w, d_out_u, k, v, beta):
    d_k[:, :t] -= einsum('nk,ns->nsk', bk[:, t], decay_w)
    d_k[:, :t] -= einsum('nk,ns->nsk', bk[:, t], decay_u)

    # d_beta
    d_beta[:, t] += einsum('nk,nk->n', w_bases[:, t], d_out_w_backward[:, t])
    d_beta[:, t] += einsum('nk,nk->n', u_bases[:, t], d_out_u_backward[:, t])

    # d_v
    d_v[:, t] = einsum('n,nv->nv', beta[:, t], d_out_u_backward[:, t])

    # backpropagate through time
    d_out_w_backward[:, :t] += einsum('nj,nk->njk', bKl[:, t, :t], d_out_w_backward[:, t])
    d_out_u_backward[:, :t] += einsum('nj,nk->njk', bKl[:, t, :t], d_out_u_backward[:, t])

    # d_beta
    d_beta += einsum('ntk,ntk->nt', w_bases, d_out_w_backward)
    d_beta += einsum('ntk,ntk->nt', u_bases, d_out_u_backward)

    # d_v
    d_v = einsum('nt,ntv->ntv', beta, d_out_u_backward)

    return d_k, d_v, d_beta


  12. proger revised this gist Jul 10, 2024. 1 changed file with 34 additions and 26 deletions.
    60 changes: 34 additions & 26 deletions deltanet.py
    @@ -320,34 +320,42 @@ def decay_values_backward(d_out_w, d_out_u, k, v, beta):
    w_bases = k - einsum('nts,nsk->ntk', K, w)
    u_bases = v - einsum('nts,nsw->ntw', K, u)

    w0 = w.clone() # we will be mutating these, but the kernel also returns the original w and u
    u0 = u.clone()

    d_out_w_backward = d_out_w.clone() # ntk
    d_out_u_backward = d_out_u.clone() # ntw

    for t in range(T-1,-1,-1):
    wk = einsum('njw,njk->nwk', w[:, :t], k[:, :t])
    w[:, t, :] = 0
    k[:, t, :] = 0
    u[:, t, :] = 0
    wk = einsum('njw,njk->nwk', w, k)
    wk = eye - wk
    uk = einsum('njw,njk->nwk', u[:, :t], k[:, :t])
    wk = einsum('n,nwk->nwk', beta[:, t], wk)
    uk = einsum('njw,njk->nwk', u, k)
    uk = einsum('n,nwk->nwk', beta[:, t], uk)

    # d_k
    wst = einsum('n,nwk->nwk', beta[:, t], wk)
    d_k[:, t] += einsum('nw,nwk->nk', d_out_w[:, t], wst)

    ust = einsum('n,nwk->nwk', beta[:, t], uk)
    d_k[:, t] -= einsum('nw,nwk->nk', d_out_u[:, t], ust)
    d_k[:, t] += einsum('nw,nwk->nk', d_out_w_backward[:, t], wk)
    d_k[:, t] -= einsum('nw,nwk->nk', d_out_u_backward[:, t], uk)

    decay_w = einsum('nw,nsw->ns', d_out_w[:, t], w[:, :t])
    decay_u = einsum('nw,nsw->ns', d_out_u[:, t], u[:, :t])
    decay_w = einsum('nw,nsw->ns', d_out_w_backward[:, t], w[:, :t])
    decay_u = einsum('nw,nsw->ns', d_out_u_backward[:, t], u[:, :t])

    d_k[:, :t] -= einsum('nk,ns->nsk', bk[:, t], decay_w)
    d_k[:, :t] -= einsum('nk,ns->nsk', bk[:, t], decay_u)

    # d_beta
    d_beta[:, t] += einsum('nk,nk->n', w_bases[:, t], d_out_w[:, t])
    d_beta[:, t] += einsum('nk,nk->n', u_bases[:, t], d_out_u[:, t])
    d_beta[:, t] += einsum('nk,nk->n', w_bases[:, t], d_out_w_backward[:, t])
    d_beta[:, t] += einsum('nk,nk->n', u_bases[:, t], d_out_u_backward[:, t])

    # d_v
    d_v[:, t] = einsum('n,nv->nv', beta[:, t], d_out_u[:, t])

    d_v[:, t] = einsum('n,nv->nv', beta[:, t], d_out_u_backward[:, t])
    # backpropagate through time
    d_out_w += einsum('nj,nk->njk', bKl[:, t, :], d_out_w[:, t])
    d_out_u += einsum('nj,nk->njk', bKl[:, t, :], d_out_u[:, t])
    d_out_w_backward[:, :t] += einsum('nj,nk->njk', bKl[:, t, :t], d_out_w_backward[:, t])
    d_out_u_backward[:, :t] += einsum('nj,nk->njk', bKl[:, t, :t], d_out_u_backward[:, t])

    return d_k, d_v, d_beta

    @@ -366,32 +374,32 @@ def backward(ctx, d_out_w, d_out_u):


    def test_equal_decay_values_backward():
    NH, T, D = 1, 4, 1
    NH, T, D = 1, 16, 3

    q, k, v, beta = make_example(NH, T, D)
    w, u = decay_values(k, v, beta)
    #(w + u - torch.ones_like(w)).pow(2).mean().backward()
    (w - torch.ones_like(w)).pow(2).mean().backward()
    (w + u - torch.ones_like(w)).pow(2).mean().backward()
    #(w - torch.ones_like(w)).pow(2).mean().backward()

    q1, k1, v1, beta1 = make_example(NH, T, D)
    w1, u1 = DecayValues.apply(k1, v1, beta1)
    #(w1 + u1 - torch.ones_like(w1)).pow(2).mean().backward()
    (w1 - torch.ones_like(w1)).pow(2).mean().backward()
    (w1 + u1 - torch.ones_like(w1)).pow(2).mean().backward()
    #(w1 - torch.ones_like(w1)).pow(2).mean().backward()

    # print(v.grad, 'v.grad', v.grad.shape)
    # print(v1.grad, 'v1.grad')
    # print(v.grad - v1.grad, 'v diff')
    # assert allclose(v.grad, v1.grad, atol=1e-5), 'v1_grad is wrong'

    print(k.grad, 'k.grad du')
    print(k1.grad, 'k1.grad du')
    print(k.grad - k1.grad, 'diff du')
    assert allclose(k.grad, k1.grad, atol=1e-5), 'k1_grad is wrong'
    assert allclose(v.grad, v1.grad, atol=1e-5), 'v1_grad is wrong'

    # print(beta.grad, 'beta.grad du')
    # print(beta1.grad, 'beta1.grad du')
    assert allclose(beta.grad, beta1.grad, atol=1e-5), 'beta1.grad is wrong'

    # print(k.grad, 'k.grad du')
    # print(k1.grad, 'k1.grad du')
    # print(k.grad - k1.grad, 'diff du')
    assert allclose(k.grad, k1.grad, atol=1e-5), 'k1_grad is wrong'

    test_equal_decay_values_backward()

    # %%
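
    (Zeroing rows w[:, t], k[:, t], u[:, t] inside the reverse loop replaces the shrinking
    [:, :t] slices of the previous version: since t runs from T-1 down to 0, rows t..T-1 are
    already zero when step t executes, so for example

        einsum('njw,njk->nwk', w, k)    # equals einsum('njw,njk->nwk', w[:, :t], k[:, :t])

    while every contraction keeps a static shape, which is what a fused kernel needs.)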
  13. proger revised this gist Jul 10, 2024. 1 changed file with 43 additions and 113 deletions.
    156 changes: 43 additions & 113 deletions deltanet.py
    @@ -291,133 +291,63 @@ def test_equal_attend_backward2(atol=1e-5):
    @no_grad()
    def decay_values_backward(d_out_w, d_out_u, k, v, beta):
    NH, T, D = shape(None, k, v, beta)
    DV = D
    K = einsum('nsd,ntd->nst', k, k) # (T,T) matrix
    beta_ = beta.unsqueeze(-1)

    w = k.new_zeros(NH, T, D)
    u = v.new_zeros(NH, T, D)

    # recompute w: same as decay_values_clone
    """
    w_0 = b_0 k_0
    w_1 = b_1 k_1 - b_1 k_0^T k_1 w_0
    w_2 = b_2 k_2 - b_2 k_0^T k_2 w_0 - b_2 k_1^T k_2 w_1
    w_t = b_t k_t - b_t \sum_{s=0}^{t-1} k_s^T k_t w_s
    u_t = b_t v_t - b_t \sum_{s=0}^{t-1} k_s^T k_t u_s
    """
    w[:, 0] = beta_[:, 0] * k[:, 0]
    u[:, 0] = beta_[:, 0] * v[:, 0]
    for t in range(1,T):
    w[:, t] = beta_[:, t] * k.clone()[:, t] - beta_[:, t] * einsum('nsj,nj,nsd->nd', k[:, :t], k[:, t], w[:, :t].clone())
    u[:, t] = beta_[:, t] * v.clone()[:, t] - beta_[:, t] * einsum('nsj,nj,nsd->nd', k[:, :t], k[:, t], u[:, :t].clone())

    """
    b0 k0 b1 k1 b2 k2
    ^ ^ ^
    | | |
    w0 <------- w1 <------- w2
    ^ ^ ^
    | | |
    l0 l1 l2
    b0 k0 v0 b0 k1 v1 b0 k2 v2
    ^ ^ ^
    | | |
    u0 <------- u1 <------- u2
    ^ ^ ^
    | | |
    m0 m1 m2
    """

    S = T

    WB = w.new_zeros(NH, T, S, D) # d w_t / d beta_s

    for t in range(T):
    WB[:, t, :t] = einsum('n,nt,ntsk->nsk', -beta[:, t], K[:, :t, t], WB[:, :t, :t])
    WB[:, t, t] = k[:, t] - einsum('nt,ntk->nk', K[:, :t, t], w[:, :t])

    """
    w_0 = b_0 k_0
    w_1 = b_1 k_1 - b_1 k_0^T k_1 w_0
    w_2 = b_2 k_2 - b_2 k_0^T k_2 w_0 - b_2 k_1^T k_2 w_1
    w_t = b_t k_t - b_t \sum_{s=0}^{t-1} k_s^T k_t w_s
    # recompute w and u TK-style
    w = k.new_zeros(NH, T, D) # ntk
    u = v.new_zeros(NH, T, D) # ntw

    u_t = b_t v_t - b_t \sum_{s=0}^{t-1} k_s^T k_t u_s
    """
    bk = einsum('nt,ntk->ntk', beta, k)
    bv = einsum('nt,ntw->ntw', beta, v)

    WK = w.new_zeros(NH, T, S, D, D) # d w_t / d k_s # ntsij
    K = einsum('ntd,nsd->nts', k, k) # (T,T) matrix
    K = K.tril(diagonal=-1) # make_causal(0); set_diagonal(0)
    bKl = einsum('nt,nts->nts', -beta, K) # multiply each row of K by beta

    for t in range(T):
    # [s<t]
    for s in range(t):
    for l in range(t):
    c = (k[:, t] * k[:, l]).sum(-1)
    if s < l:
    WK[:, t, s] -= einsum('n,n,nij->nij', beta[:, t], c, WK[:, l, s])
    if s == l:
    WK[:, t, s] -= einsum('n,nj,ni->nij', beta[:, t], k[:, t], w[:, s])
    WK[:, t, s] -= einsum('n,n,nij->nij', beta[:, t], c, WK[:, l, l])
    # [s=t]

    WK[:, t, t, arange(D), arange(D)] = beta_[:, t]
    WK[:, t, t] += einsum('n,nsj,nsi->nij', -beta[:, t], k[:, :t], w[:, :t])
    c_w = einsum('nts,nsk->ntk', bKl, w)
    w[:, t] = bk[:, t, :] + c_w[:, t, :]
    c_u = einsum('nts,nsw->ntw', bKl, u)
    u[:, t] = bv[:, t, :] + c_u[:, t, :]

    """
    u_t = b_t v_t - b_t \sum_{s=0}^{t-1} k_s^T k_t u_s
    # compute gradients for d_k, d_v, d_beta
    d_k = k.new_zeros(NH, T, D) # nsk
    d_beta = beta.new_zeros(NH, T) # ns
    d_v = v.new_zeros(NH, T, D) # nsv

    d u_t / d v_s = [s=t] b_t
    + [s<t] - b_t \sum_{j=0}^{t-1} k_j^T k_t (d u_j / d v_s)
    """
    eye = torch.eye(D, device=k.device, dtype=k.dtype)
    eye = eye.unsqueeze(0).expand(NH, D, D)

    UV = u.new_zeros(NH, T, S, DV) # d u_t / d v_s
    w_bases = k - einsum('nts,nsk->ntk', K, w)
    u_bases = v - einsum('nts,nsw->ntw', K, u)

    for t in range(T):
    # [s<t]
    UV[:, t, :t] = einsum('n,nt,ntsv->nsv', -beta[:, t], K[:, :t, t], UV[:, :t, :t])
    for t in range(T-1,-1,-1):
    wk = einsum('njw,njk->nwk', w[:, :t], k[:, :t])
    wk = eye - wk
    uk = einsum('njw,njk->nwk', u[:, :t], k[:, :t])

    # [s=t]
    UV[:, t, t] = beta_[:, t]
    # d_k
    wst = einsum('n,nwk->nwk', beta[:, t], wk)
    d_k[:, t] += einsum('nw,nwk->nk', d_out_w[:, t], wst)

    """
    d u_t / d k_s =
    - b_t \sum_{l=0}^{t-1} (D_{k_s} k_l^T k_t u_l)
    ust = einsum('n,nwk->nwk', beta[:, t], uk)
    d_k[:, t] -= einsum('nw,nwk->nk', d_out_u[:, t], ust)

    D_{k_s} k_l^T k_t u_l = [s=t] k_l u_l^T # sum out dimensions of u_j after the outer product?
    D_{k_s} k_l^T k_t u_l = [s<t][s<l]: only u_l is a function of k_s
    [s<t][s=l]: product rule for u_l(k_s) and k_l=k_s
    """
    decay_w = einsum('nw,nsw->ns', d_out_w[:, t], w[:, :t])
    decay_u = einsum('nw,nsw->ns', d_out_u[:, t], u[:, :t])

    UK = u.new_zeros(NH, T, S, D, DV) # d u_t / d k_s
    d_k[:, :t] -= einsum('nk,ns->nsk', bk[:, t], decay_w)
    d_k[:, :t] -= einsum('nk,ns->nsk', bk[:, t], decay_u)

    for t in range(T):
    # [s<t]
    for s in range(t):
    for l in range(t):
    c = (k[:, t] * k[:, l]).sum(-1)
    if s < l:
    UK[:, t, s] -= einsum('n,n,nij->nij', beta[:, t], c, UK[:, l, s])
    if s == l:
    UK[:, t, s] -= einsum('n,nk,nv->nkv', beta[:, t], k[:, t], u[:, s])
    UK[:, t, s] -= einsum('n,n,nij->nij', beta[:, t], c, UK[:, l, l])
    # [s=t]
    UK[:, t, t] -= einsum('n,nsk,nsv->nkv', beta[:, t], k[:, :t], u[:, :t])

    UB = u.new_zeros(NH, T, S, D) # d u_t / d beta_s
    # d_beta
    d_beta[:, t] += einsum('nk,nk->n', w_bases[:, t], d_out_w[:, t])
    d_beta[:, t] += einsum('nk,nk->n', u_bases[:, t], d_out_u[:, t])

    for t in range(T):
    for s in range(t):
    for l in range(t):
    c = (k[:, t] * k[:, l]).sum(-1)
    UB[:, t, s] -= einsum('n,n,nj->nj', beta[:, t], c, UB[:, l, s])

    UB[:, t, t] = v[:, t] - einsum('ns,nsd->nd', K[:, :t, t], u[:, :t])
    # d_v
    d_v[:, t] = einsum('n,nv->nv', beta[:, t], d_out_u[:, t])

    d_beta = einsum('ntk,ntsk->ns', d_out_w, WB) + einsum('ntv,ntsv->ns', d_out_u, UB) # sum T and D out
    d_k = einsum('nti,ntsij->nsj', d_out_w, WK) + einsum('ntv,ntskv->nsk', d_out_u, UK) # sum T out
    d_v = einsum('ntv,ntsv->nsv', d_out_u, UV) # sum T out
    # backpropagate through time
    d_out_w += einsum('nj,nk->njk', bKl[:, t, :], d_out_w[:, t])
    d_out_u += einsum('nj,nk->njk', bKl[:, t, :], d_out_u[:, t])

    return d_k, d_v, d_beta

    @@ -436,7 +366,7 @@ def backward(ctx, d_out_w, d_out_u):


    def test_equal_decay_values_backward():
    NH, T, D = 2, 8, 3
    NH, T, D = 1, 4, 1

    q, k, v, beta = make_example(NH, T, D)
    w, u = decay_values(k, v, beta)
    @@ -677,4 +607,4 @@ def test_delta_chunkwise_backward():
    assert allclose(v1.grad, v.grad, atol=1e-5), 'v.grad is wrong'


    test_delta_chunkwise_backward()
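
    (This revision replaces the stored Jacobian blocks WB/WK/UV/UK/UB with a single
    reverse-time sweep. Once d_out_w and d_out_u have been pushed backward through time,
    the remaining gradients take closed forms, since v_t and beta_t enter the recurrences
    only through their leading terms:

        d_v[t]    = b_t * d_out_u[t]
        d_beta[t] = d_out_w[t] . (k_t - \sum_{s<t} k_s^T k_t w_s)
                  + d_out_u[t] . (v_t - \sum_{s<t} k_s^T k_t u_s)

    which is what the w_bases and u_bases terms compute. A hand-written backward like this
    can also be checked numerically; a sketch in float64 with small shapes, using the
    DecayValues function from this revision:

        k = torch.randn(1, 4, 3, dtype=torch.double, requires_grad=True)
        v = torch.randn(1, 4, 3, dtype=torch.double, requires_grad=True)
        beta = torch.rand(1, 4, dtype=torch.double, requires_grad=True)
        torch.autograd.gradcheck(DecayValues.apply, (k, v, beta))
    )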
  14. proger revised this gist Jul 2, 2024. 1 changed file with 32 additions and 32 deletions.
    64 changes: 32 additions & 32 deletions deltanet.py
    @@ -120,38 +120,6 @@ def stitch1_forward(state, q, k, w, u):
    return y, state + state_add


    def stitch1_backward(d_y, d_state, state, q, k, w, u):
    state_decays = einsum('nvk,ntk->ntv', state, w)

    d_q1 = einsum('nsv,nvk->nsk', d_y, state) # prev_output
    d_state1 = einsum('nsv,nsk->nvk', d_y, q) # prev_output

    d_q2, d_k1, d_state_decays1 = causal_attend_backward(-d_y, q, k, state_decays) # delta

    d_k2 = einsum('nvk,ntv->ntk', d_state, u - state_decays) # state_add
    d_u = einsum('nvk,ntk->ntv', d_state, k) # state_add
    d_state_decays2 = einsum('nvk,ntk->ntv', -d_state, k) # state_add

    d_state_decays = d_state_decays1 + d_state_decays2
    d_w = einsum('ntv,nvk->ntk', d_state_decays, state) # state_decays
    d_state2 = einsum('ntv,ntk->nvk', d_state_decays, w) # state_decays

    return d_state + d_state1 + d_state2, d_q1 + d_q2, d_k1 + d_k2, d_w, d_u


    class Stitch1(torch.autograd.Function):
    @staticmethod
    def forward(ctx, state, q, k, w, u):
    y, new_state = stitch1_forward(state, q, k, w, u)
    ctx.save_for_backward(state, q, k, w, u)
    return y, new_state

    @staticmethod
    def backward(ctx, d_y, d_state):
    state, q, k, w, u = ctx.saved_tensors
    return stitch1_backward(d_y, d_state, state, q, k, w, u)


    def forward_ogloop(q, k, v, beta):
    "reference: w_t = w_{t-1} + beta_t (v_t - w_t k_t) k_t"
    NH, T, D = shape(q, k, v, beta)
    @@ -499,6 +467,38 @@ def test_equal_decay_values_backward():
    # %%


    def stitch1_backward(d_y, d_state, state, q, k, w, u):
    state_decays = einsum('nvk,ntk->ntv', state, w)

    d_q1 = einsum('nsv,nvk->nsk', d_y, state) # prev_output
    d_state1 = einsum('nsv,nsk->nvk', d_y, q) # prev_output

    d_q2, d_k1, d_state_decays1 = causal_attend_backward(-d_y, q, k, state_decays) # delta

    d_k2 = einsum('nvk,ntv->ntk', d_state, u - state_decays) # state_add
    d_u = einsum('nvk,ntk->ntv', d_state, k) # state_add
    d_state_decays2 = einsum('nvk,ntk->ntv', -d_state, k) # state_add

    d_state_decays = d_state_decays1 + d_state_decays2
    d_w = einsum('ntv,nvk->ntk', d_state_decays, state) # state_decays
    d_state2 = einsum('ntv,ntk->nvk', d_state_decays, w) # state_decays

    return d_state + d_state1 + d_state2, d_q1 + d_q2, d_k1 + d_k2, d_w, d_u


    class Stitch1(torch.autograd.Function):
    @staticmethod
    def forward(ctx, state, q, k, w, u):
    y, new_state = stitch1_forward(state, q, k, w, u)
    ctx.save_for_backward(state, q, k, w, u)
    return y, new_state

    @staticmethod
    def backward(ctx, d_y, d_state):
    state, q, k, w, u = ctx.saved_tensors
    return stitch1_backward(d_y, d_state, state, q, k, w, u)


    def stitch_backward(d_y_delta, q, k, w, u, C, chunk_size):
    NH, T, D = shape(q, k, None, None)
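
    (For orientation, stitch1_backward differentiates both outputs of stitch1_forward. In
    matrix form the map is

        y      = q state^T - tril(q k^T) (w state^T)
        state' = state + (u - w state^T)^T k

    and each einsum in the backward is the adjoint of one factor in these products.)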

  15. proger revised this gist Jul 2, 2024. 1 changed file with 12 additions and 44 deletions.
    56 changes: 12 additions & 44 deletions deltanet.py
    @@ -83,11 +83,11 @@ def forward_chunkwise(q, k, v, beta, chunk_size=2):
    y = causal_attend(q_, k_, u)

    # stitch chunks sequentially
    y_delta = forward_stitch(q, k, w, u, C=C, chunk_size=chunk_size)
    y_delta, _ = stitch_forward(q, k, w, u, C=C, chunk_size=chunk_size)
    return y.view(NH, T, D) + y_delta.view(NH, T, D)


    def forward_stitch(q, k, w, u, C, chunk_size):
    def stitch_forward(q, k, w, u, C, chunk_size):
    "stitch chunks sequentially"
    NH, T, D = shape(q, k, None, None)

    @@ -96,15 +96,17 @@ def forward_stitch(q, k, w, u, C, chunk_size):
    u = u.view(NH, C, chunk_size, D)
    w = w.view(NH, C, chunk_size, D)

    y_delta = u.new_zeros(NH, C, chunk_size, D)

    # materialize the state for the leading chunk
    state = einsum('ntv,ntk->nvk', u[:, 0], k_[:, 0])

    deltas = [u.new_zeros(NH, chunk_size, D)]
    for c in range(1, C):
    y_delta[:, c], state = stitch1_forward(state, q_[:, c], k_[:, c], w[:, c], u[:, c])
    y_delta1, state = stitch1_forward(state, q_[:, c], k_[:, c], w[:, c], u[:, c])
    deltas.append(y_delta1)

    return y_delta
    y_delta = torch.stack(deltas, dim=1)

    return y_delta, state


    def stitch1_forward(state, q, k, w, u):
    @@ -496,27 +498,6 @@ def test_equal_decay_values_backward():

    # %%

    def stitch_forward(q, k, w, u, C, chunk_size):
    "stitch chunks sequentially"
    NH, T, D = shape(q, k, None, None)

    q_ = q.view(NH, C, chunk_size, D)
    k_ = k.view(NH, C, chunk_size, D)
    u = u.view(NH, C, chunk_size, D)
    w = w.view(NH, C, chunk_size, D)

    # materialize the state for the leading chunk
    state = einsum('ntv,ntk->nvk', u[:, 0], k_[:, 0])

    deltas = [u.new_zeros(NH, chunk_size, D)]
    for c in range(1, C):
    y_delta1, state = stitch1_forward(state, q_[:, c], k_[:, c], w[:, c], u[:, c])
    deltas.append(y_delta1)

    y_delta = torch.stack(deltas, dim=1)

    return y_delta, state


    def stitch_backward(d_y_delta, q, k, w, u, C, chunk_size):
    NH, T, D = shape(q, k, None, None)
    @@ -628,23 +609,10 @@ def test_stitch_all(atol=1e-5):
    class DeltaChunkwise(torch.autograd.Function):
    @staticmethod
    def forward(ctx, q, k, v, beta, chunk_size):
    NH, T, D = shape(q, k, v, beta)
    C = T // chunk_size
    q_, k_, v_, beta_ = (
    q.view(NH*C, chunk_size, D), k.view(NH*C, chunk_size, D),
    v.view(NH*C, chunk_size, D), beta.view(NH*C, chunk_size)
    )

    # evaluate all chunks in parallel
    w, u = decay_values(k_, v_, beta_)
    y = causal_attend(q_, k_, u)

    # stitch chunks sequentially
    y_delta = forward_stitch(q, k, w, u, C=C, chunk_size=chunk_size)

    y = forward_chunkwise(q, k, v, beta, chunk_size)
    ctx.save_for_backward(q, k, v, beta)
    ctx.chunk_size = chunk_size
    return y.view(NH, T, D) + y_delta.view(NH, T, D)
    return y

    @staticmethod
    def backward(ctx, d_y):
    @@ -685,7 +653,7 @@ def backward(ctx, d_y):
    return d_q, d_k, d_v, d_beta, None


    def test_delta_simple_backward():
    def test_delta_chunkwise_backward():
    NH, T, D = 2, 16, 2
    q1, k1, v1, beta1 = make_example(NH, T, D)

    @@ -709,4 +677,4 @@ def test_delta_simple_backward():
    assert allclose(v1.grad, v.grad, atol=1e-5), 'v.grad is wrong'


    test_delta_simple_backward()
    test_delta_chunkwise_backward()
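
    (With DeltaChunkwise.forward reduced to a thin wrapper, the chunkwise path can also be
    sanity-checked directly against the sequential reference, e.g.:

        q, k, v, beta = make_example(2, 16, 4)
        assert allclose(forward_chunkwise(q, k, v, beta, chunk_size=4),
                        forward_ogloop(q, k, v, beta), atol=1e-5)
    )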
  16. proger revised this gist Jul 2, 2024. 1 changed file with 4 additions and 22 deletions.
    26 changes: 4 additions & 22 deletions deltanet.py
    @@ -318,24 +318,6 @@ def test_equal_attend_backward2(atol=1e-5):

    #%%

    def decay_values_forward(k, v, beta):
    "decay values applying deltanet forgetting rules (autograd-compatible, uses clone in the loop)"
    NH, T, D = shape(None, k, v, beta)

    w = k.new_zeros(NH, T, D)
    u = v.new_zeros(NH, T, D)
    beta_ = beta.unsqueeze(-1)

    K = einsum('nsd,ntd->nst', k, k) # (T,T) matrix

    w[:, 0] = beta_[:, 0] * k[:, 0]
    u[:, 0] = beta_[:, 0] * v[:, 0]
    for t in range(1,T):
    w[:, t] = beta_[:, t] * k.clone()[:, t] - beta_[:, t] * einsum('ns,nsd->nd', K[:, :t, t], w[:, :t].clone())
    u[:, t] = beta_[:, t] * v.clone()[:, t] - beta_[:, t] * einsum('ns,nsd->nd', K[:, :t, t], u[:, :t].clone())

    return w, u

    @no_grad()
    def decay_values_backward(d_out_w, d_out_u, k, v, beta):
    NH, T, D = shape(None, k, v, beta)
    @@ -473,7 +455,7 @@ def decay_values_backward(d_out_w, d_out_u, k, v, beta):
    class DecayValues(torch.autograd.Function):
    @staticmethod
    def forward(ctx, k, v, beta):
    w, u = decay_values_forward(k, v, beta)
    w, u = decay_values(k, v, beta)
    ctx.save_for_backward(k, v, beta)
    return w, u

    @@ -487,7 +469,7 @@ def test_equal_decay_values_backward():
    NH, T, D = 2, 8, 3

    q, k, v, beta = make_example(NH, T, D)
    w, u = decay_values_forward(k, v, beta)
    w, u = decay_values(k, v, beta)
    #(w + u - torch.ones_like(w)).pow(2).mean().backward()
    (w - torch.ones_like(w)).pow(2).mean().backward()

    @@ -591,7 +573,7 @@ def test_stitch_all(atol=1e-5):
    NH, T, D = 1, 4, 2
    C, chunk_size = 2, 2
    q, k, v, beta = make_example(NH, T, D)
    w, u = decay_values_forward(k, v, beta)
    w, u = decay_values(k, v, beta)

    q.requires_grad_()
    k.requires_grad_()
    @@ -612,7 +594,7 @@ def test_stitch_all(atol=1e-5):
    # print(u.grad, 'u.grad')

    q1, k1, v1, beta1 = make_example(NH, T, D)
    w1, u1 = decay_values_forward(k1, v1, beta1)
    w1, u1 = decay_values(k1, v1, beta1)

    q1.requires_grad_()
    k1.requires_grad_()
  17. proger revised this gist Jul 2, 2024. 1 changed file with 66 additions and 26 deletions.
    92 changes: 66 additions & 26 deletions deltanet.py
    @@ -324,14 +324,15 @@ def decay_values_forward(k, v, beta):

    w = k.new_zeros(NH, T, D)
    u = v.new_zeros(NH, T, D)
    beta_ = beta.unsqueeze(-1)

    K = einsum('nsd,ntd->nst', k, k) # (T,T) matrix

    w[:, 0] = beta[:, 0] * k[:, 0]
    u[:, 0] = beta[:, 0] * v[:, 0]
    w[:, 0] = beta_[:, 0] * k[:, 0]
    u[:, 0] = beta_[:, 0] * v[:, 0]
    for t in range(1,T):
    w[:, t] = beta[:, t] * k.clone()[:, t] - beta[:, t] * einsum('ns,nsd->nd', K[:, :t, t], w[:, :t].clone())
    u[:, t] = beta[:, t] * v.clone()[:, t] - beta[:, t] * einsum('ns,nsd->nd', K[:, :t, t], u[:, :t].clone())
    w[:, t] = beta_[:, t] * k.clone()[:, t] - beta_[:, t] * einsum('ns,nsd->nd', K[:, :t, t], w[:, :t].clone())
    u[:, t] = beta_[:, t] * v.clone()[:, t] - beta_[:, t] * einsum('ns,nsd->nd', K[:, :t, t], u[:, :t].clone())

    return w, u

    @@ -340,6 +341,7 @@ def decay_values_backward(d_out_w, d_out_u, k, v, beta):
    NH, T, D = shape(None, k, v, beta)
    DV = D
    K = einsum('nsd,ntd->nst', k, k) # (T,T) matrix
    beta_ = beta.unsqueeze(-1)

    w = k.new_zeros(NH, T, D)
    u = v.new_zeros(NH, T, D)
    @@ -353,11 +355,11 @@ def decay_values_backward(d_out_w, d_out_u, k, v, beta):
    u_t = b_t v_t - b_t \sum_{s=0}^{t-1} k_s^T k_t u_s
    """
    w[:, 0] = beta[:, 0] * k[:, 0]
    u[:, 0] = beta[:, 0] * v[:, 0]
    w[:, 0] = beta_[:, 0] * k[:, 0]
    u[:, 0] = beta_[:, 0] * v[:, 0]
    for t in range(1,T):
    w[:, t] = beta[:, t] * k.clone()[:, t] - beta[:, t] * einsum('nsj,nj,nsd->nd', k[:, :t], k[:, t], w[:, :t].clone())
    u[:, t] = beta[:, t] * v.clone()[:, t] - beta[:, t] * einsum('nsj,nj,nsd->nd', k[:, :t], k[:, t], u[:, :t].clone())
    w[:, t] = beta_[:, t] * k.clone()[:, t] - beta_[:, t] * einsum('nsj,nj,nsd->nd', k[:, :t], k[:, t], w[:, :t].clone())
    u[:, t] = beta_[:, t] * v.clone()[:, t] - beta_[:, t] * einsum('nsj,nj,nsd->nd', k[:, :t], k[:, t], u[:, :t].clone())

    """
    b0 k0 b1 k1 b2 k2
    @@ -402,14 +404,14 @@ def decay_values_backward(d_out_w, d_out_u, k, v, beta):
    for l in range(t):
    c = (k[:, t] * k[:, l]).sum(-1)
    if s < l:
    WK[:, t, s] -= beta[:, t] * c * WK[:, l, s]
    WK[:, t, s] -= einsum('n,n,nij->nij', beta[:, t], c, WK[:, l, s])
    if s == l:
    WK[:, t, s] -= beta[:, t] * einsum('nj,ni->nij', k[:, t], w[:, s])
    WK[:, t, s] -= beta[:, t] * c * WK[:, l, l]
    WK[:, t, s] -= einsum('n,nj,ni->nij', beta[:, t], k[:, t], w[:, s])
    WK[:, t, s] -= einsum('n,n,nij->nij', beta[:, t], c, WK[:, l, l])
    # [s=t]

    WK[:, t, t, arange(D), arange(D)] = beta[:, t]
    WK[:, t, t] += -beta[:, t] * einsum('nsj,nsi->nij', k[:, :t], w[:, :t])
    WK[:, t, t, arange(D), arange(D)] = beta_[:, t]
    WK[:, t, t] += einsum('n,nsj,nsi->nij', -beta[:, t], k[:, :t], w[:, :t])

    """
    u_t = b_t v_t - b_t \sum_{s=0}^{t-1} k_s^T k_t u_s
    @@ -425,7 +427,7 @@ def decay_values_backward(d_out_w, d_out_u, k, v, beta):
    UV[:, t, :t] = einsum('n,nt,ntsv->nsv', -beta[:, t], K[:, :t, t], UV[:, :t, :t])

    # [s=t]
    UV[:, t, t] = beta[:, t]
    UV[:, t, t] = beta_[:, t]

    """
    d u_t / d k_s =
    @@ -444,20 +446,20 @@ def decay_values_backward(d_out_w, d_out_u, k, v, beta):
    for l in range(t):
    c = (k[:, t] * k[:, l]).sum(-1)
    if s < l:
    UK[:, t, s] -= beta[:, t] * c * UK[:, l, s]
    UK[:, t, s] -= einsum('n,n,nij->nij', beta[:, t], c, UK[:, l, s])
    if s == l:
    UK[:, t, s] -= beta[:, t] * einsum('nk,nv->nkv', k[:, t], u[:, s])
    UK[:, t, s] -= beta[:, t] * c * UK[:, l, l]
    UK[:, t, s] -= einsum('n,nk,nv->nkv', beta[:, t], k[:, t], u[:, s])
    UK[:, t, s] -= einsum('n,n,nij->nij', beta[:, t], c, UK[:, l, l])
    # [s=t]
    UK[:, t, t] -= beta[:, t] * einsum('nsk,nsv->nkv', k[:, :t], u[:, :t])
    UK[:, t, t] -= einsum('n,nsk,nsv->nkv', beta[:, t], k[:, :t], u[:, :t])

    UB = u.new_zeros(NH, T, S, D) # d u_t / d beta_s

    for t in range(T):
    for s in range(t):
    for l in range(t):
    c = (k[:, t] * k[:, l]).sum(-1)
    UB[:, t, s] -= beta[:, t] * c * UB[:, l, s]
    UB[:, t, s] -= einsum('n,n,nj->nj', beta[:, t], c, UB[:, l, s])

    UB[:, t, t] = v[:, t] - einsum('ns,nsd->nd', K[:, :t, t], u[:, :t])

    @@ -482,7 +484,7 @@ def backward(ctx, d_out_w, d_out_u):


    def test_equal_decay_values_backward():
    NH, T, D = 1, 8, 3
    NH, T, D = 2, 8, 3

    q, k, v, beta = make_example(NH, T, D)
    w, u = decay_values_forward(k, v, beta)
    @@ -669,22 +671,60 @@ def backward(ctx, d_y):
    chunk_size = ctx.chunk_size
    C = T // chunk_size

    d_y = d_y.view(NH, C, chunk_size, D)
    q_, k_, v_, beta_ = (
    q.view(NH*C, chunk_size, D), k.view(NH*C, chunk_size, D),
    v.view(NH*C, chunk_size, D), beta.view(NH*C, chunk_size)
    )

    w, u = decay_values(k_, v_, beta_)

    d_q_1, d_k_1, d_w, d_u1 = stitch_backward(d_y, q_, k_, w, u, C=C, chunk_size=chunk_size)
    d_q_1, d_k_1, d_w, d_u1 = stitch_backward(d_y, q, k, w, u, C=C, chunk_size=chunk_size)

    d_w = d_w.view(NH*C, chunk_size, D)
    d_u1 = d_u1.view(NH*C, chunk_size, D)
    d_y = d_y.view(NH*C, chunk_size, D)
    u = u.view(NH*C, chunk_size, D)

    d_q_2, d_k_2, d_u2 = causal_attend_backward(d_y, q_, k_, u)
    d_u = d_u1 + d_u2
    d_k_3, d_v_, d_beta_ = decay_values_backward(d_w, d_u, k_, v_, beta_)

    d_q = (d_q_1 + d_q_2).view(NH, T, D)
    d_k = (d_k_1 + d_k_2 + d_k_3).view(NH, T, D)
    d_v = d_v_.view(NH, T, D)
    d_q_1 = d_q_1.view(NH, T, D)
    d_q_2 = d_q_2.view(NH, C, chunk_size, D)
    d_q_2 = d_q_2.reshape(NH, T, D)
    d_q = d_q_1 + d_q_2

    d_k_2 = d_k_2.reshape(NH, T, D)
    d_k_3 = d_k_3.reshape(NH, T, D)
    d_k = d_k_1 + d_k_2 + d_k_3
    d_v = d_v_.reshape(NH, T, D)
    d_beta = d_beta_.view(NH, T)

    return d_q, d_k, d_v, d_beta, None


    def test_delta_simple_backward():
    NH, T, D = 2, 16, 2
    q1, k1, v1, beta1 = make_example(NH, T, D)

    y1 = forward_chunkwise(q1, k1, v1, beta1, chunk_size=2)
    (y1 - torch.ones_like(y1).detach()).pow(2).mean().backward()

    q, k, v, beta = make_example(NH, T, D)
    y = DeltaChunkwise.apply(q, k, v, beta, 2)
    (y - torch.ones_like(y).detach()).pow(2).mean().backward()

    assert allclose(y1, y, atol=1e-5), 'y is wrong'

    # print(beta1.grad - beta.grad, 'beta.grad diff')
    # print(q1.grad - q.grad, 'q.grad diff')
    # print(k1.grad - k.grad, 'k.grad diff')
    # print(v1.grad - v.grad, 'v.grad diff')

    assert allclose(q1.grad, q.grad, atol=1e-5), 'q.grad is wrong'
    assert allclose(beta1.grad, beta.grad, atol=1e-5), 'beta.grad is wrong'
    assert allclose(k1.grad, k.grad, atol=1e-5), 'k.grad is wrong'
    assert allclose(v1.grad, v.grad, atol=1e-5), 'v.grad is wrong'


    test_delta_simple_backward()
  18. proger revised this gist Jul 2, 2024. 1 changed file with 16 additions and 7 deletions.
    23 changes: 16 additions & 7 deletions deltanet.py
    @@ -385,7 +385,16 @@ def decay_values_backward(d_out_w, d_out_u, k, v, beta):
    WB[:, t, :t] = einsum('n,nt,ntsk->nsk', -beta[:, t], K[:, :t, t], WB[:, :t, :t])
    WB[:, t, t] = k[:, t] - einsum('nt,ntk->nk', K[:, :t, t], w[:, :t])

    WK = w.new_zeros(NH, T, S, D, D) # d w_t / d k_s
    """
    w_0 = b_0 k_0
    w_1 = b_1 k_1 - b_1 k_0^T k_1 w_0
    w_2 = b_2 k_2 - b_2 k_0^T k_2 w_0 - b_2 k_1^T k_2 w_1
    w_t = b_t k_t - b_t \sum_{s=0}^{t-1} k_s^T k_t w_s
    u_t = b_t v_t - b_t \sum_{s=0}^{t-1} k_s^T k_t u_s
    """

    WK = w.new_zeros(NH, T, S, D, D) # d w_t / d k_s # ntsij

    for t in range(T):
    # [s<t]
    @@ -395,11 +404,12 @@ def decay_values_backward(d_out_w, d_out_u, k, v, beta):
    if s < l:
    WK[:, t, s] -= beta[:, t] * c * WK[:, l, s]
    if s == l:
    WK[:, t, s] -= beta[:, t] * einsum('nk,nw->nkw', k[:, t], w[:, s])
    WK[:, t, s] -= beta[:, t] * einsum('nj,ni->nij', k[:, t], w[:, s])
    WK[:, t, s] -= beta[:, t] * c * WK[:, l, l]
    # [s=t]
    WK[:, t, t] = beta[:, t]
    WK[:, t, t] += -beta[:, t] * einsum('nsk,nsw->nkw', k[:, :t], w[:, :t])

    WK[:, t, t, arange(D), arange(D)] = beta[:, t]
    WK[:, t, t] += -beta[:, t] * einsum('nsj,nsi->nij', k[:, :t], w[:, :t])

    """
    u_t = b_t v_t - b_t \sum_{s=0}^{t-1} k_s^T k_t u_s
    @@ -452,8 +462,7 @@ def decay_values_backward(d_out_w, d_out_u, k, v, beta):
    UB[:, t, t] = v[:, t] - einsum('ns,nsd->nd', K[:, :t, t], u[:, :t])

    d_beta = einsum('ntk,ntsk->ns', d_out_w, WB) + einsum('ntv,ntsv->ns', d_out_u, UB) # sum T and D out
    ##### what is wrong with WK???
    d_k = einsum('ntk,ntskw->nsk', d_out_w, WK) + einsum('ntv,ntskv->nsk', d_out_u, UK) # sum T out
    d_k = einsum('nti,ntsij->nsj', d_out_w, WK) + einsum('ntv,ntskv->nsk', d_out_u, UK) # sum T out
    d_v = einsum('ntv,ntsv->nsv', d_out_u, UV) # sum T out

    return d_k, d_v, d_beta
    @@ -473,7 +482,7 @@ def backward(ctx, d_out_w, d_out_u):


    def test_equal_decay_values_backward():
    NH, T, D = 1, 4, 2
    NH, T, D = 1, 8, 3

    q, k, v, beta = make_example(NH, T, D)
    w, u = decay_values_forward(k, v, beta)
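
    (The arange(D), arange(D) indexing writes b_t onto the diagonal of the (D, D) Jacobian
    block d w_t / d k_t, i.e. the b_t I term of w_t = b_t k_t - ...; a minimal illustration
    with hypothetical shapes:

        J = k.new_zeros(NH, D, D)
        J[:, arange(D), arange(D)] = beta[:, t].unsqueeze(-1)  # J[n] = beta[n, t] * eye(D)
    )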
  19. proger revised this gist Jul 1, 2024. 1 changed file with 101 additions and 201 deletions.
    302 changes: 101 additions & 201 deletions deltanet.py
    @@ -336,8 +336,9 @@ def decay_values_forward(k, v, beta):
    return w, u

    @no_grad()
    def decay_values_backward(k, v, beta):
    def decay_values_backward(d_out_w, d_out_u, k, v, beta):
    NH, T, D = shape(None, k, v, beta)
    DV = D
    K = einsum('nsd,ntd->nst', k, k) # (T,T) matrix

    w = k.new_zeros(NH, T, D)
    @@ -381,24 +382,24 @@ def decay_values_backward(k, v, beta):
    WB = w.new_zeros(NH, T, S, D) # d w_t / d beta_s

    for t in range(T):
    WB[:, t, :t] = einsum('n,nt,ntsd->nsd', -beta[:, t], K[:, :t, t], WB[:, :t, :t])
    WB[:, t, t] = k[:, t] - einsum('nt,ntd->nd', K[:, :t, t], w[:, :t])
    WB[:, t, :t] = einsum('n,nt,ntsk->nsk', -beta[:, t], K[:, :t, t], WB[:, :t, :t])
    WB[:, t, t] = k[:, t] - einsum('nt,ntk->nk', K[:, :t, t], w[:, :t])

    WK = w.new_zeros(NH, T, S, D) # d w_t / d k_s
    WK = w.new_zeros(NH, T, S, D, D) # d w_t / d k_s

    for t in range(T):
    # [s<t]
    for s in range(t):
    # for j in range(t): [s<j]
    WK[:, t, s] -= beta[:, t] * einsum('nj,njd->nd', K[:, s+1:t, t], WK[:, s+1:t, s])

    # for s in range(t): for j in range(t): [s<t][s=j]
    WK[:, t, :t] -= beta[:, t] * einsum('nh,nsd->nsh', k[:, t], w[:, :t])
    WK[:, t, :t] -= beta[:, t] * einsum('ns,nsd->nsd', K[:, :t, t], WK[:, arange(t), arange(t)])

    for l in range(t):
    c = (k[:, t] * k[:, l]).sum(-1)
    if s < l:
    WK[:, t, s] -= beta[:, t] * c * WK[:, l, s]
    if s == l:
    WK[:, t, s] -= beta[:, t] * einsum('nk,nw->nkw', k[:, t], w[:, s])
    WK[:, t, s] -= beta[:, t] * c * WK[:, l, l]
    # [s=t]
    WK[:, t, t] = beta[:, t]
    WK[:, t, t] += -beta[:, t] * einsum('nsj,nsi->nj', k[:, :t], w[:, :t])
    WK[:, t, t] += -beta[:, t] * einsum('nsk,nsw->nkw', k[:, :t], w[:, :t])

    """
    u_t = b_t v_t - b_t \sum_{s=0}^{t-1} k_s^T k_t u_s
    @@ -407,11 +408,11 @@ def decay_values_backward(k, v, beta):
    + [s<t] - b_t \sum_{j=0}^{t-1} k_j^T k_t (d u_j / d v_s)
    """

    UV = u.new_zeros(NH, T, S, D) # d u_t / d v_s
    UV = u.new_zeros(NH, T, S, DV) # d u_t / d v_s

    for t in range(T):
    # [s<t]
    UV[:, t, :t] = einsum('n,nt,ntsd->nsd', -beta[:, t], K[:, :t, t], UV[:, :t, :t])
    UV[:, t, :t] = einsum('n,nt,ntsv->nsv', -beta[:, t], K[:, :t, t], UV[:, :t, :t])

    # [s=t]
    UV[:, t, t] = beta[:, t]
    @@ -425,100 +426,8 @@ def decay_values_backward(k, v, beta):
    [s<t][s=l]: product rule for u_l(k_s) and k_l=k_s
    """

    UK = u.new_zeros(NH, T, S, D) # d u_t / d k_s

    for t in range(T):
    # [s<t]
    for s in range(t):
    for l in range(t):
    c = (k[:, t] * k[:, l]).sum(-1)
    if s < l:
    UK[:, t, s] -= beta[:, t] * c * UK[:, l, s]
    if s == l:
    UK[:, t, s] -= beta[:, t] * einsum('nd,nh->nd', k[:, t], u[:, s])
    UK[:, t, s] -= beta[:, t] * c * UK[:, l, l]

    # [s=t]
    UK[:, t, t] -= beta[:, t] * einsum('nsd,nsh->nd', k[:, :t], u[:, :t])

    UB = u.new_zeros(NH, T, S, D) # d u_t / d beta_s

    for t in range(T):
    for s in range(t):
    for l in range(t):
    c = (k[:, t] * k[:, l]).sum(-1)
    UB[:, t, s] -= beta[:, t] * c * UB[:, l, s]

    UB[:, t, t] = v[:, t] - einsum('ns,nsd->nd', K[:, :t, t], u[:, :t])

    d_beta = einsum('ntsd->ns', WB) + einsum('ntsd->ns', UB) # sum T and D out
    d_k = einsum('ntsd->nsd', WK) + einsum('ntsd->nsd', UK) # sum T out

    d_v = einsum('ntsd->nsd', UV) # sum T out

    return w, u, d_k, d_v, d_beta


    def test_equal_decay_values_backward():
    NH, T, D = 1, 4, 2

    q, k, v, beta = make_example(NH, T, D)

    ## to try these modify summations in d_beta and d_k above
    # k, v, beta = make()
    # w, u = decay_values_forward(k, v, beta)
    # w.sum().backward()
    # w, u, k_grad, v_grad, beta_grad = decay_values_backward(k, v, beta)
    # # print(k.grad, 'k.grad')
    # # print(k_grad, 'k_grad')
    # # print(k.grad - k_grad, 'diff')
    # assert allclose(k.grad, k_grad, atol=1e-5), 'k_grad is wrong'
    # assert allclose(beta.grad, beta_grad, atol=1e-5), 'beta_grad is wrong'

    q, k, v, beta = make_example(NH, T, D)
    w, u = decay_values_forward(k, v, beta)
    (w + u).sum().backward()
    w, u, k_grad, v_grad, beta_grad = decay_values_backward(k, v, beta)
    # print(v.grad, 'v.grad', v.grad.shape)
    # print(v_grad, 'v_grad')
    # print(v.grad - v_grad, 'v diff')
    assert allclose(v.grad, v_grad, atol=1e-5), 'v_grad is wrong'

    # print(k.grad, 'k.grad du')
    # print(k_grad, 'k_grad du')
    # print(k.grad - k_grad, 'diff du')
    assert allclose(k.grad, k_grad, atol=1e-5), 'k_grad wrt u is wrong'

    # print(beta.grad, 'beta.grad du')
    # print(beta_grad, 'beta_grad du')
    assert allclose(beta.grad, beta_grad, atol=1e-5), 'beta_grad wrt u is wrong'

    test_equal_decay_values_backward()

    #%%


    @no_grad()
    def decay_values_backward_only_u(d_out, k, v, beta):
    "trimmed version of decay_values_backward that does not use w"

    NH, T, D = shape(None, k, v, beta)
    DV = D
    K = einsum('nsd,ntd->nst', k, k) # (T,T) matrix

    u = v.new_zeros(NH, T, D)
    u[:, 0] = beta[:, 0] * v[:, 0]
    for t in range(1,T):
    u[:, t] = beta[:, t] * v.clone()[:, t] - beta[:, t] * einsum('nsj,nj,nsd->nd', k[:, :t], k[:, t], u[:, :t].clone())

    S = T

    UV = u.new_zeros(NH, T, S, D) # d u_t / d v_s
    for t in range(T):
    UV[:, t, :t] = einsum('n,nt,ntsd->nsd', -beta[:, t], K[:, :t, t], UV[:, :t, :t])
    UV[:, t, t] = beta[:, t]

    UK = u.new_zeros(NH, T, S, D, DV) # d u_t / d k_s

    for t in range(T):
    # [s<t]
    for s in range(t):
    @@ -542,116 +451,55 @@ def decay_values_backward_only_u(d_out, k, v, beta):

    UB[:, t, t] = v[:, t] - einsum('ns,nsd->nd', K[:, :t, t], u[:, :t])

    d_beta = einsum('ntv,ntsv->ns', d_out, UB)
    d_k = einsum('ntv,ntskv->nsk', d_out, UK)
    d_v = einsum('ntv,ntsv->nsv', d_out, UV)

    return u, d_k, d_v, d_beta

    d_beta = einsum('ntk,ntsk->ns', d_out_w, WB) + einsum('ntv,ntsv->ns', d_out_u, UB) # sum T and D out
    ##### what is wrong with WK???
    d_k = einsum('ntk,ntskw->nsk', d_out_w, WK) + einsum('ntv,ntskv->nsk', d_out_u, UK) # sum T out
    d_v = einsum('ntv,ntsv->nsv', d_out_u, UV) # sum T out

    def forward_simple1(q, k, v, beta):
    NH, T, D = shape(q, k, v, beta)
    K = einsum('nsd,ntd->nst', k, k) # (T,T) matrix

    u = einsum('nt,ntd->ntd', beta, v.clone())
    for t in range(1,T):
    u[:, t] -= einsum('n,nt,ntd->nd', beta[:,t], K[:, :t, t], u[:, :t].clone())
    return d_k, d_v, d_beta

    return u, causal_attend(q, k, u)


    class DeltaSimple(torch.autograd.Function):
    class DecayValues(torch.autograd.Function):
    @staticmethod
    def forward(ctx, q, k, v, beta):
    u, y = forward_simple1(q, k, v, beta)
    ctx.save_for_backward(q, k, v, beta, u)
    return y
    def forward(ctx, k, v, beta):
    w, u = decay_values_forward(k, v, beta)
    ctx.save_for_backward(k, v, beta)
    return w, u

    @staticmethod
    def backward(ctx, d_out):
    q, k, v, beta, u = ctx.saved_tensors

    d_q, d_k1, d_u = causal_attend_backward(d_out, q, k, u)
    u, d_k, d_v, d_beta = decay_values_backward_only_u(d_u, k, v, beta)
    d_k = d_k1 + d_k

    return d_q, d_k, d_v, d_beta
    def backward(ctx, d_out_w, d_out_u):
    k, v, beta = ctx.saved_tensors
    return decay_values_backward(d_out_w, d_out_u, k, v, beta)


    def test_delta_simple_backward():
    NH, T, D = 1, 3, 2
    q1, k1, v1, beta1 = make_example(NH, T, D)
    u1, y1 = forward_simple1(q1, k1, v1, beta1)
    (y1 - torch.ones_like(y1).detach()).pow(2).mean().backward()

    q, k, v, beta = make_example(NH, T, D)
    y = DeltaSimple.apply(q, k, v, beta)
    (y - torch.ones_like(y).detach()).pow(2).mean().backward()

    assert allclose(y1, y, atol=1e-5), 'y is wrong'

    print(beta1.grad - beta.grad, 'beta.grad diff')
    #print(q1.grad - q.grad, 'q.grad diff')
    print(k1.grad - k.grad, 'k.grad diff')
    print(v1.grad - v.grad, 'v.grad diff')

    assert allclose(q1.grad, q.grad, atol=1e-5), 'q.grad is wrong'
    assert allclose(beta1.grad, beta.grad, atol=1e-5), 'beta.grad is wrong'
    assert allclose(k1.grad, k.grad, atol=1e-5), 'k.grad is wrong'
    assert allclose(v1.grad, v.grad, atol=1e-5), 'v.grad is wrong'


    test_delta_simple_backward()

    # %%


    def test_stitch1(atol=1e-5):
    def test_equal_decay_values_backward():
    NH, T, D = 1, 4, 2

    q, k, v, beta = make_example(NH, T, D)
    w, u = decay_values_forward(k, v, beta)
    state = torch.randn_like(einsum('ntv,ntk->nvk', v, k))

    q.requires_grad_()
    k.requires_grad_()
    v.requires_grad_()
    beta.requires_grad_()
    w.retain_grad()
    u.retain_grad()
    state.requires_grad_()

    y, new_state = stitch1_forward(state, q, k, w, u)
    loss = (y - torch.ones_like(y)).pow(2).mean() + (new_state - torch.ones_like(new_state)).pow(2).mean()
    loss.backward()
    #(w + u - torch.ones_like(w)).pow(2).mean().backward()
    (w - torch.ones_like(w)).pow(2).mean().backward()

    q1, k1, v1, beta1 = make_example(NH, T, D)
    w1, u1 = decay_values_forward(k1, v1, beta1)
    state1 = torch.randn_like(einsum('ntv,ntk->nvk', v1, k1))

    q1.requires_grad_()
    k1.requires_grad_()
    v1.requires_grad_()
    beta1.requires_grad_()
    w1.retain_grad()
    u1.retain_grad()
    state1.requires_grad_()
    w1, u1 = DecayValues.apply(k1, v1, beta1)
    #(w1 + u1 - torch.ones_like(w1)).pow(2).mean().backward()
    (w1 - torch.ones_like(w1)).pow(2).mean().backward()

    y1, new_state1 = Stitch1.apply(state1, q1, k1, w1, u1)
    loss = (y1 - torch.ones_like(y1)).pow(2).mean() + (new_state1 - torch.ones_like(new_state1)).pow(2).mean()
    loss.backward()
    # print(v.grad, 'v.grad', v.grad.shape)
    # print(v1.grad, 'v1.grad')
    # print(v.grad - v1.grad, 'v diff')
    # assert allclose(v.grad, v1.grad, atol=1e-5), 'v1_grad is wrong'

    assert allclose(y, y1, atol=atol), 'y is wrong'
    assert allclose(new_state, new_state1, atol=atol), 'new_state is wrong'
    print(k.grad, 'k.grad du')
    print(k1.grad, 'k1.grad du')
    print(k.grad - k1.grad, 'diff du')
    assert allclose(k.grad, k1.grad, atol=1e-5), 'k1_grad is wrong'

    assert allclose(u.grad, u1.grad, atol=atol), 'u.grad is wrong'
    assert allclose(v.grad, v1.grad, atol=atol), 'v.grad is wrong'
    assert allclose(k.grad, k1.grad, atol=atol), 'k.grad is wrong'
    assert allclose(q.grad, q1.grad, atol=atol), 'q.grad is wrong'
    assert allclose(beta.grad, beta1.grad, atol=atol), 'beta.grad is wrong'
    assert allclose(w.grad, w1.grad, atol=atol), 'w.grad is wrong'
    assert allclose(state.grad, state1.grad, atol=atol), 'state.grad is wrong'
    # print(beta.grad, 'beta.grad du')
    # print(beta1.grad, 'beta1.grad du')
    assert allclose(beta.grad, beta1.grad, atol=1e-5), 'beta1.grad is wrong'

    test_stitch1()
    test_equal_decay_values_backward()

    # %%

    @@ -778,4 +626,56 @@ def test_stitch_all(atol=1e-5):
    assert allclose(w.grad, w1.grad, atol=atol), 'w.grad is wrong'


    test_stitch_all()


    #%%


    class DeltaChunkwise(torch.autograd.Function):
    @staticmethod
    def forward(ctx, q, k, v, beta, chunk_size):
    NH, T, D = shape(q, k, v, beta)
    C = T // chunk_size
    q_, k_, v_, beta_ = (
    q.view(NH*C, chunk_size, D), k.view(NH*C, chunk_size, D),
    v.view(NH*C, chunk_size, D), beta.view(NH*C, chunk_size)
    )

    # evaluate all chunks in parallel
    w, u = decay_values(k_, v_, beta_)
    y = causal_attend(q_, k_, u)

    # stitch chunks sequentially
    y_delta = forward_stitch(q, k, w, u, C=C, chunk_size=chunk_size)

    ctx.save_for_backward(q, k, v, beta)
    ctx.chunk_size = chunk_size
    return y.view(NH, T, D) + y_delta.view(NH, T, D)

    @staticmethod
    def backward(ctx, d_y):
    q, k, v, beta = ctx.saved_tensors
    NH, T, D = shape(q, k, v, beta)
    chunk_size = ctx.chunk_size
    C = T // chunk_size

    d_y = d_y.view(NH, C, chunk_size, D)
    q_, k_, v_, beta_ = (
    q.view(NH*C, chunk_size, D), k.view(NH*C, chunk_size, D),
    v.view(NH*C, chunk_size, D), beta.view(NH*C, chunk_size)
    )

    w, u = decay_values(k_, v_, beta_)

    d_q_1, d_k_1, d_w, d_u1 = stitch_backward(d_y, q_, k_, w, u, C=C, chunk_size=chunk_size)
    d_q_2, d_k_2, d_u2 = causal_attend_backward(d_y, q_, k_, u)
    d_u = d_u1 + d_u2
    d_k_3, d_v_, d_beta_ = decay_values_backward(d_w, d_u, k_, v_, beta_)

    d_q = (d_q_1 + d_q_2).view(NH, T, D)
    d_k = (d_k_1 + d_k_2 + d_k_3).view(NH, T, D)
    d_v = d_v_.view(NH, T, D)
    d_beta = d_beta_.view(NH, T)

    return d_q, d_k, d_v, d_beta, None
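
    (The three contractions over WB/UB, WK/UK and UV in this revision are the plain chain
    rule with the stored Jacobian tensors:

        d_k[s]    = \sum_t d_out_w[t] . (d w_t / d k_s) + \sum_t d_out_u[t] . (d u_t / d k_s)
        d_v[s]    = \sum_t d_out_u[t] . (d u_t / d v_s)      # diagonal in the value dimension
        d_beta[s] = \sum_t d_out_w[t] . (d w_t / d beta_s) + \sum_t d_out_u[t] . (d u_t / d beta_s)

    Materializing WK and UK costs O(T^2 D^2) memory per head, which the reverse-time sweep
    introduced in the Jul 10 revisions above avoids.)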
  20. proger revised this gist Jul 1, 2024. 1 changed file with 218 additions and 11 deletions.
    229 changes: 218 additions & 11 deletions deltanet.py
    @@ -96,28 +96,58 @@ def forward_stitch(q, k, w, u, C, chunk_size):
    u = u.view(NH, C, chunk_size, D)
    w = w.view(NH, C, chunk_size, D)

    y_delta = q.new_zeros(NH, C, chunk_size, D)
    y_delta = u.new_zeros(NH, C, chunk_size, D)

    # materialize the state for the leading chunk
    state = einsum('ntv,ntk->nvk', u[:, 0], k_[:, 0])
    mask = q.new_ones(chunk_size, chunk_size).tril()

    for c in range(1, C):
    y_delta[:, c], state = stitch1(state, q_[:, c], k_[:, c], w[:, c], u[:, c], mask)
    y_delta[:, c], state = stitch1_forward(state, q_[:, c], k_[:, c], w[:, c], u[:, c])

    return y_delta


    def stitch1(state, q, k, w, u, mask):
    def stitch1_forward(state, q, k, w, u):
    state_decays = einsum('nvk,ntk->ntv', state, w)
    state_add = einsum('ntv,ntk->nvk', u - state_decays, k)

    delta = causal_attend(q, k, state_decays)
    prev_output = einsum('nvk,nsk->nsv', state, q)
    state_decays = einsum('nvk,nsk->nsv', state, w)
    y = prev_output - delta

    return y, state + state_add


    def stitch1_backward(d_y, d_state, state, q, k, w, u):
    state_decays = einsum('nvk,ntk->ntv', state, w)

    d_q1 = einsum('nsv,nvk->nsk', d_y, state) # prev_output
    d_state1 = einsum('nsv,nsk->nvk', d_y, q) # prev_output

    d_q2, d_k1, d_state_decays1 = causal_attend_backward(-d_y, q, k, state_decays) # delta

    # delta = causal_attend(q, k, state_decays)
    delta = einsum('nsk,ntk,st,ntv->nsv', q, k, mask, state_decays)
    y_delta1 = prev_output - delta
    d_k2 = einsum('nvk,ntv->ntk', d_state, u - state_decays) # state_add
    d_u = einsum('nvk,ntk->ntv', d_state, k) # state_add
    d_state_decays2 = einsum('nvk,ntk->ntv', -d_state, k) # state_add

    state_add = einsum('nsv,nsk->nvk', u - state_decays, k)
    return y_delta1, state + state_add
    d_state_decays = d_state_decays1 + d_state_decays2
    d_w = einsum('ntv,nvk->ntk', d_state_decays, state) # state_decays
    d_state2 = einsum('ntv,ntk->nvk', d_state_decays, w) # state_decays

    return d_state + d_state1 + d_state2, d_q1 + d_q2, d_k1 + d_k2, d_w, d_u


    class Stitch1(torch.autograd.Function):
    @staticmethod
    def forward(ctx, state, q, k, w, u):
    y, new_state = stitch1_forward(state, q, k, w, u)
    ctx.save_for_backward(state, q, k, w, u)
    return y, new_state

    @staticmethod
    def backward(ctx, d_y, d_state):
    state, q, k, w, u = ctx.saved_tensors
    return stitch1_backward(d_y, d_state, state, q, k, w, u)


    def forward_ogloop(q, k, v, beta):
    @@ -571,4 +601,181 @@ def test_delta_simple_backward():
    assert allclose(v1.grad, v.grad, atol=1e-5), 'v.grad is wrong'


    test_delta_simple_backward()

    # %%


    def test_stitch1(atol=1e-5):
    NH, T, D = 1, 4, 2
    q, k, v, beta = make_example(NH, T, D)
    w, u = decay_values_forward(k, v, beta)
    state = torch.randn_like(einsum('ntv,ntk->nvk', v, k))

    q.requires_grad_()
    k.requires_grad_()
    v.requires_grad_()
    beta.requires_grad_()
    w.retain_grad()
    u.retain_grad()
    state.requires_grad_()

    y, new_state = stitch1_forward(state, q, k, w, u)
    loss = (y - torch.ones_like(y)).pow(2).mean() + (new_state - torch.ones_like(new_state)).pow(2).mean()
    loss.backward()

    q1, k1, v1, beta1 = make_example(NH, T, D)
    w1, u1 = decay_values_forward(k1, v1, beta1)
    state1 = torch.randn_like(einsum('ntv,ntk->nvk', v1, k1))

    q1.requires_grad_()
    k1.requires_grad_()
    v1.requires_grad_()
    beta1.requires_grad_()
    w1.retain_grad()
    u1.retain_grad()
    state1.requires_grad_()

    y1, new_state1 = Stitch1.apply(state1, q1, k1, w1, u1)
    loss = (y1 - torch.ones_like(y1)).pow(2).mean() + (new_state1 - torch.ones_like(new_state1)).pow(2).mean()
    loss.backward()

    assert allclose(y, y1, atol=atol), 'y is wrong'
    assert allclose(new_state, new_state1, atol=atol), 'new_state is wrong'

    assert allclose(u.grad, u1.grad, atol=atol), 'u.grad is wrong'
    assert allclose(v.grad, v1.grad, atol=atol), 'v.grad is wrong'
    assert allclose(k.grad, k1.grad, atol=atol), 'k.grad is wrong'
    assert allclose(q.grad, q1.grad, atol=atol), 'q.grad is wrong'
    assert allclose(beta.grad, beta1.grad, atol=atol), 'beta.grad is wrong'
    assert allclose(w.grad, w1.grad, atol=atol), 'w.grad is wrong'
    assert allclose(state.grad, state1.grad, atol=atol), 'state.grad is wrong'

    test_stitch1()

    # %%

    def stitch_forward(q, k, w, u, C, chunk_size):
    "stitch chunks sequentially"
    NH, T, D = shape(q, k, None, None)

    q_ = q.view(NH, C, chunk_size, D)
    k_ = k.view(NH, C, chunk_size, D)
    u = u.view(NH, C, chunk_size, D)
    w = w.view(NH, C, chunk_size, D)

    # materialize the state for the leading chunk
    state = einsum('ntv,ntk->nvk', u[:, 0], k_[:, 0])

    deltas = [u.new_zeros(NH, chunk_size, D)]
    for c in range(1, C):
    y_delta1, state = stitch1_forward(state, q_[:, c], k_[:, c], w[:, c], u[:, c])
    deltas.append(y_delta1)

    y_delta = torch.stack(deltas, dim=1)

    return y_delta, state


    def stitch_backward(d_y_delta, q, k, w, u, C, chunk_size):
    NH, T, D = shape(q, k, None, None)

    d_y_delta = d_y_delta.view(NH, C, chunk_size, D)
    q_ = q.view(NH, C, chunk_size, D)
    k_ = k.view(NH, C, chunk_size, D)
    u = u.view(NH, C, chunk_size, D)
    w = w.view(NH, C, chunk_size, D)

    d_q_ = q.new_zeros(NH, C, chunk_size, D)
    d_k_ = k.new_zeros(NH, C, chunk_size, D)
    d_w = w.new_zeros(NH, C, chunk_size, D)
    d_u = u.new_zeros(NH, C, chunk_size, D)
    d_state = w.new_zeros(NH, D, D) # NHVK
    y_delta = u.new_zeros(NH, C, chunk_size, D) # leading chunk has zero delta

    # storing all states for BPTT
    states = k.new_zeros(NH, C, D, D) # NHCVK
    # materialize the state for the leading chunk
    states[:, 0] = einsum('ntv,ntk->nvk', u[:, 0], k_[:, 0])

    for c in range(1, C):
    y_delta[:, c], states[:, c] = stitch1_forward(states[:, c-1], q_[:, c], k_[:, c], w[:, c], u[:, c])

    for c in range(C-1, 0, -1):
    (
    d_state, d_q_[:, c], d_k_[:, c], d_w[:, c], d_u[:, c]
    ) = stitch1_backward(d_y_delta[:, c], d_state, states[:, c-1], q_[:, c], k_[:, c], w[:, c], u[:, c])

    (
    d_state, d_q_[:, 0], d_k_[:, 0], d_w[:, 0], d_u[:, 0]
    ) = stitch1_backward(d_y_delta[:, 0], d_state, torch.zeros_like(d_state), q_[:, 0], k_[:, 0], w[:, 0], u[:, 0])

    return d_q_.view(NH, T, D), d_k_.view(NH, T, D), d_w.view(NH, T, D), d_u.view(NH, T, D)


    class Stitch(torch.autograd.Function):
    @staticmethod
    def forward(ctx, q, k, w, u, C, chunk_size):
    y_delta, state = stitch_forward(q, k, w, u, C, chunk_size)
    ctx.save_for_backward(q, k, w, u)
    ctx.C = C
    ctx.chunk_size = chunk_size
    return y_delta

    @staticmethod
    def backward(ctx, d_y_delta):
    q, k, w, u = ctx.saved_tensors
    return *stitch_backward(d_y_delta, q, k, w, u, ctx.C, ctx.chunk_size), None, None


    def test_stitch_all(atol=1e-5):
    NH, T, D = 1, 4, 2
    C, chunk_size = 2, 2
    q, k, v, beta = make_example(NH, T, D)
    w, u = decay_values_forward(k, v, beta)

    q.requires_grad_()
    k.requires_grad_()
    v.requires_grad_()
    beta.requires_grad_()
    w.retain_grad()
    u.retain_grad()

    y, new_state = stitch_forward(q, k, w, u, C=C, chunk_size=chunk_size)
    loss = (y - torch.ones_like(y)).pow(2).mean()
    loss.backward()

    # print(q.grad, 'q.grad')
    # print(k.grad, 'k.grad')
    # print(v.grad, 'v.grad')
    # print(beta.grad, 'beta.grad')
    # print(w.grad, 'w.grad')
    # print(u.grad, 'u.grad')

    q1, k1, v1, beta1 = make_example(NH, T, D)
    w1, u1 = decay_values_forward(k1, v1, beta1)

    q1.requires_grad_()
    k1.requires_grad_()
    v1.requires_grad_()
    beta1.requires_grad_()
    w1.retain_grad()
    u1.retain_grad()

    y1 = Stitch.apply(q1, k1, w1, u1, C, chunk_size)
    loss = (y1 - torch.ones_like(y1)).pow(2).mean()
    loss.backward()

    assert allclose(y, y1, atol=atol), 'y is wrong'

    assert allclose(u.grad, u1.grad, atol=atol), 'u.grad is wrong'
    assert allclose(v.grad, v1.grad, atol=atol), 'v.grad is wrong'
    # print(k.grad, 'k.grad')
    # print(k1.grad, 'k1.grad')
    assert allclose(k.grad, k1.grad, atol=atol), 'k.grad is wrong'
    assert allclose(q.grad, q1.grad, atol=atol), 'q.grad is wrong'
    assert allclose(beta.grad, beta1.grad, atol=atol), 'beta.grad is wrong'
    assert allclose(w.grad, w1.grad, atol=atol), 'w.grad is wrong'


    test_stitch_all()
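
A hand-written torch.autograd.Function like Stitch1 can also be validated with torch.autograd.gradcheck, which compares the analytic backward against finite differences. A minimal sketch (not part of the gist), assuming Stitch1 from the revision above is in scope:

import torch

def check_stitch1(NH=1, T=4, D=2):
    # gradcheck compares stitch1_backward against finite differences,
    # so it wants float64 inputs with requires_grad set
    def make(*shape):
        return torch.randn(*shape, dtype=torch.float64, requires_grad=True)
    state, q, k, w, u = make(NH, D, D), make(NH, T, D), make(NH, T, D), make(NH, T, D), make(NH, T, D)
    assert torch.autograd.gradcheck(Stitch1.apply, (state, q, k, w, u))

check_stitch1()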
  21. proger revised this gist Jul 1, 2024. 1 changed file with 2 additions and 4 deletions.
    6 changes: 2 additions & 4 deletions deltanet.py
    @@ -103,9 +103,7 @@ def forward_stitch(q, k, w, u, C, chunk_size):
    mask = q.new_ones(chunk_size, chunk_size).tril()

    for c in range(1, C):
    y_delta1, state_add = stitch1(state, q_[:, c], k_[:, c], w[:, c], u[:, c], mask)
    y_delta[:, c] = y_delta1
    state.add_(state_add)
    y_delta[:, c], state = stitch1(state, q_[:, c], k_[:, c], w[:, c], u[:, c], mask)

    return y_delta

    @@ -119,7 +117,7 @@ def stitch1(state, q, k, w, u, mask):
    y_delta1 = prev_output - delta

    state_add = einsum('nsv,nsk->nvk', u - state_decays, k)
    return y_delta1, state_add
    return y_delta1, state + state_add


    def forward_ogloop(q, k, v, beta):
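
This revision replaces the caller-side in-place state.add_(state_add) with a functional state + state_add. The distinction matters once gradients flow through the state: mutating a tensor that autograd has saved for backward invalidates the graph. A toy demonstration of the failure mode (illustrative only, not from the gist):

import torch

s = torch.randn(3, requires_grad=True)
y = s.exp()        # autograd saves y to compute the gradient of exp
out = (y * 2).sum()
y.add_(1)          # in-place update of a tensor saved for backward
try:
    out.backward()
except RuntimeError as e:
    print('in-place update broke the graph:', e)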
  22. proger revised this gist Jul 1, 2024. 1 changed file with 15 additions and 5 deletions.
    20 changes: 15 additions & 5 deletions deltanet.py
    @@ -100,16 +100,26 @@ def forward_stitch(q, k, w, u, C, chunk_size):

    # materialize the state for the leading chunk
    state = einsum('ntv,ntk->nvk', u[:, 0], k_[:, 0])
    mask = q.new_ones(chunk_size, chunk_size).tril()

    for c in range(1, C):
    prev_output = einsum('nvk,ntk->ntv', state, q_[:, c])
    state_decays = einsum('nvk,ntk->ntv', state, w[:, c])
    y_delta1, state_add = stitch1(state, q_[:, c], k_[:, c], w[:, c], u[:, c], mask)
    y_delta[:, c] = y_delta1
    state.add_(state_add)

    y_delta[:, c] = prev_output - causal_attend(q_[:, c], k_[:, c], state_decays)
    return y_delta

    state.add_(einsum('ntv,ntk->nvk', u[:, c] - state_decays, k_[:, c]))

    return y_delta
    def stitch1(state, q, k, w, u, mask):
    prev_output = einsum('nvk,nsk->nsv', state, q)
    state_decays = einsum('nvk,nsk->nsv', state, w)

    # delta = causal_attend(q, k, state_decays)
    delta = einsum('nsk,ntk,st,ntv->nsv', q, k, mask, state_decays)
    y_delta1 = prev_output - delta

    state_add = einsum('nsv,nsk->nvk', u - state_decays, k)
    return y_delta1, state_add


    def forward_ogloop(q, k, v, beta):
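
The explicit masked einsum introduced in stitch1 is causal_attend with the lower-triangular mask inlined. A quick numerical check of that equivalence, assuming causal_attend defined earlier in the file is in scope:

import torch
from torch import einsum

NH, T, D = 1, 4, 2
q, k, v = torch.randn(3, NH, T, D, dtype=torch.float64).unbind(0)
mask = q.new_ones(T, T).tril()
delta_einsum = einsum('nsk,ntk,st,ntv->nsv', q, k, mask, v)
assert torch.allclose(delta_einsum, causal_attend(q, k, v))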
  23. proger revised this gist Jul 1, 2024. 1 changed file with 25 additions and 15 deletions.
    40 changes: 25 additions & 15 deletions deltanet.py
    @@ -82,25 +82,34 @@ def forward_chunkwise(q, k, v, beta, chunk_size=2):
    w, u = decay_values(k_, v_, beta_)
    y = causal_attend(q_, k_, u)

    q_ = q_.view(NH, C, chunk_size, D)
    k_ = k_.view(NH, C, chunk_size, D)
    w = w.view(NH, C, chunk_size, D)
    # stitch chunks sequentially
    y_delta = forward_stitch(q, k, w, u, C=C, chunk_size=chunk_size)
    return y.view(NH, T, D) + y_delta.view(NH, T, D)


    def forward_stitch(q, k, w, u, C, chunk_size):
    "stitch chunks sequentially"
    NH, T, D = shape(q, k, None, None)

    q_ = q.view(NH, C, chunk_size, D)
    k_ = k.view(NH, C, chunk_size, D)
    u = u.view(NH, C, chunk_size, D)
    y = y.view(NH, C, chunk_size, D)
    w = w.view(NH, C, chunk_size, D)

    # stitch chunks sequentially
    state = v.new_zeros(NH, D, D) # NHVK: keys last
    for c in range(C):
    prev_output = einsum('nvk,ntk->ntv', state, q_[:, c])
    y[:, c].add_(prev_output)
    y_delta = q.new_zeros(NH, C, chunk_size, D)

    # materialize the state for the leading chunk
    state = einsum('ntv,ntk->nvk', u[:, 0], k_[:, 0])

    for c in range(1, C):
    prev_output = einsum('nvk,ntk->ntv', state, q_[:, c])
    state_decays = einsum('nvk,ntk->ntv', state, w[:, c])
    y[:, c].sub_(causal_attend(q_[:, c], k_[:, c], state_decays))

    state.sub_(einsum('ntv,ntk->nvk', state_decays, k_[:, c]))
    state.add_(einsum('nti,ntj->nij', u[:, c], k_[:, c]))

    return y.view(NH, T, D)
    y_delta[:, c] = prev_output - causal_attend(q_[:, c], k_[:, c], state_decays)

    state.add_(einsum('ntv,ntk->nvk', u[:, c] - state_decays, k_[:, c]))

    return y_delta


    def forward_ogloop(q, k, v, beta):
    @@ -153,7 +162,8 @@ def shape(q, k, v, beta=None):
    NH, T, D = (q if q is not None else k).shape
    if q is not None:
    assert q.shape == (NH, T, D)
    assert k.shape == v.shape
    if v is not None:
    assert k.shape == v.shape
    if beta is not None:
    assert beta.shape == (NH, T)
    return NH, T, D
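
Written out in the notation the gist uses elsewhere, the chunk-level update implemented by forward_stitch above is:

# with S the (V, K) state carried in from previous chunks and s, t indexing within a chunk:
#   y_delta_t = S q_t - sum_{s<=t} (q_t . k_s) (S w_s)      # added to the intra-chunk attention over u
#   S_next    = S + sum_t (u_t - S w_t) k_t^T
#             = S (I - sum_t w_t k_t^T) + sum_t u_t k_t^T   # the same update, rearranged

so each chunk applies a rank-chunk_size correction to the running state.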
  24. proger revised this gist Jul 1, 2024. 1 changed file with 39 additions and 26 deletions.
    65 changes: 39 additions & 26 deletions deltanet.py
    @@ -451,10 +451,11 @@ def test_equal_decay_values_backward():


    @no_grad()
    def decay_values_backward_only_u(k, v, beta):
    def decay_values_backward_only_u(d_out, k, v, beta):
    "trimmed version of decay_values_backward that does not use w"

    NH, T, D = shape(None, k, v, beta)
    DV = D
    K = einsum('nsd,ntd->nst', k, k) # (T,T) matrix

    u = v.new_zeros(NH, T, D)
    @@ -469,7 +470,7 @@ def decay_values_backward_only_u(k, v, beta):
    UV[:, t, :t] = einsum('n,nt,ntsd->nsd', -beta[:, t], K[:, :t, t], UV[:, :t, :t])
    UV[:, t, t] = beta[:, t]

    UK = u.new_zeros(NH, T, S, D) # d u_t / d k_s
    UK = u.new_zeros(NH, T, S, D, DV) # d u_t / d k_s
    for t in range(T):
    # [s<t]
    for s in range(t):
    @@ -478,10 +479,10 @@ def decay_values_backward_only_u(k, v, beta):
    if s < l:
    UK[:, t, s] -= beta[:, t] * c * UK[:, l, s]
    if s == l:
    UK[:, t, s] -= beta[:, t] * einsum('nd,nh->nd', k[:, t], u[:, s])
    UK[:, t, s] -= beta[:, t] * einsum('nk,nv->nkv', k[:, t], u[:, s])
    UK[:, t, s] -= beta[:, t] * c * UK[:, l, l]
    # [s=t]
    UK[:, t, t] -= beta[:, t] * einsum('nsd,nsh->nd', k[:, :t], u[:, :t])
    UK[:, t, t] -= beta[:, t] * einsum('nsk,nsv->nkv', k[:, :t], u[:, :t])

    UB = u.new_zeros(NH, T, S, D) # d u_t / d beta_s

    @@ -493,51 +494,63 @@ def decay_values_backward_only_u(k, v, beta):

    UB[:, t, t] = v[:, t] - einsum('ns,nsd->nd', K[:, :t, t], u[:, :t])

    d_beta = einsum('ntsd->ns', UB) # sum T and D out
    d_k = einsum('ntsd->nsd', UK) # sum T out
    d_v = einsum('ntsd->nsd', UV) # sum T out
    d_beta = einsum('ntv,ntsv->ns', d_out, UB)
    d_k = einsum('ntv,ntskv->nsk', d_out, UK)
    d_v = einsum('ntv,ntsv->nsv', d_out, UV)

    return u, d_k, d_v, d_beta


    def simple_backward(d_out, q, k, v, beta):
    u, du_dk, du_dv, du_dbeta = decay_values_backward_only_u(k, v, beta)
    d_q, d_k1, d_u = causal_attend_backward(d_out, q, k, u)
    d_k = d_k1 + d_u * du_dk
    d_v = d_out * d_u * du_dv
    d_beta = (d_out * d_u * du_dbeta.unsqueeze(-1)).sum(-1)
    def forward_simple1(q, k, v, beta):
    NH, T, D = shape(q, k, v, beta)
    K = einsum('nsd,ntd->nst', k, k) # (T,T) matrix

    u = einsum('nt,ntd->ntd', beta, v.clone())
    for t in range(1,T):
    u[:, t] -= einsum('n,nt,ntd->nd', beta[:,t], K[:, :t, t], u[:, :t].clone())

    return d_q, d_k, d_v, d_beta
    return u, causal_attend(q, k, u)


    class DeltaSimple(torch.autograd.Function):
    @staticmethod
    def forward(ctx, q, k, v, beta):
    ctx.save_for_backward(q, k, v, beta)
    w, u = decay_values(k, v, beta)
    return causal_attend(q, k, u)
    u, y = forward_simple1(q, k, v, beta)
    ctx.save_for_backward(q, k, v, beta, u)
    return y

    @staticmethod
    def backward(ctx, d_out):
    q, k, v, beta = ctx.saved_tensors
    return simple_backward(d_out, q, k, v, beta)
    q, k, v, beta, u = ctx.saved_tensors

    d_q, d_k1, d_u = causal_attend_backward(d_out, q, k, u)
    u, d_k, d_v, d_beta = decay_values_backward_only_u(d_u, k, v, beta)
    d_k = d_k1 + d_k

    return d_q, d_k, d_v, d_beta


    def test_delta_simple_backward():
    NH, T, D = 1, 16, 2
    NH, T, D = 1, 3, 2
    q1, k1, v1, beta1 = make_example(NH, T, D)
    y1 = forward_simple(q1, k1, v1, beta1)
    u1, y1 = forward_simple1(q1, k1, v1, beta1)
    (y1 - torch.ones_like(y1).detach()).pow(2).mean().backward()

    q, k, v, beta = make_example(NH, T, D)
    y = DeltaSimple.apply(q, k, v, beta)
    (y - torch.ones_like(y).detach()).pow(2).mean().backward()

    # print(beta1.grad - beta.grad, 'beta.grad diff')
    # print(k1.grad - k.grad, 'k.grad diff')
    # print(v1.grad - v.grad, 'v.grad diff')

    assert allclose(y1, y, atol=1e-5), 'y is wrong'

    print(beta1.grad - beta.grad, 'beta.grad diff')
    #print(q1.grad - q.grad, 'q.grad diff')
    print(k1.grad - k.grad, 'k.grad diff')
    print(v1.grad - v.grad, 'v.grad diff')

    assert allclose(q1.grad, q.grad, atol=1e-5), 'q.grad is wrong'
    assert allclose(beta1.grad, beta.grad, atol=1e-5), 'beta.grad is wrong'
    assert allclose(k1.grad, k.grad, atol=1e-5), 'k.grad is wrong'
    assert allclose(v1.grad, v.grad, atol=1e-5), 'v.grad is wrong'


    test_delta_simple_backward()
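
test_delta_simple_backward compares the hand-derived DeltaSimple backward against autograd on one example; torch.autograd.gradcheck adds an independent finite-difference check. A sketch, assuming DeltaSimple from this revision is in scope:

import torch

def check_delta_simple(NH=1, T=3, D=2):
    # float64 keeps the finite-difference comparison tight
    q, k, v = (torch.randn(NH, T, D, dtype=torch.float64, requires_grad=True) for _ in range(3))
    beta = torch.rand(NH, T, dtype=torch.float64, requires_grad=True)
    assert torch.autograd.gradcheck(DeltaSimple.apply, (q, k, v, beta))

check_delta_simple()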
  25. proger revised this gist Jul 1, 2024. 1 changed file with 95 additions and 0 deletions.
    95 changes: 95 additions & 0 deletions deltanet.py
    @@ -446,3 +446,98 @@ def test_equal_decay_values_backward():
    assert allclose(beta.grad, beta_grad, atol=1e-5), 'beta_grad wrt u is wrong'

    test_equal_decay_values_backward()

    #%%


    @no_grad()
    def decay_values_backward_only_u(k, v, beta):
    "trimmed version of decay_values_backward that does not use w"

    NH, T, D = shape(None, k, v, beta)
    K = einsum('nsd,ntd->nst', k, k) # (T,T) matrix

    u = v.new_zeros(NH, T, D)
    u[:, 0] = beta[:, 0] * v[:, 0]
    for t in range(1,T):
    u[:, t] = beta[:, t] * v.clone()[:, t] - beta[:, t] * einsum('nsj,nj,nsd->nd', k[:, :t], k[:, t], u[:, :t].clone())

    S = T

    UV = u.new_zeros(NH, T, S, D) # d u_t / d v_s
    for t in range(T):
    UV[:, t, :t] = einsum('n,nt,ntsd->nsd', -beta[:, t], K[:, :t, t], UV[:, :t, :t])
    UV[:, t, t] = beta[:, t]

    UK = u.new_zeros(NH, T, S, D) # d u_t / d k_s
    for t in range(T):
    # [s<t]
    for s in range(t):
    for l in range(t):
    c = (k[:, t] * k[:, l]).sum(-1)
    if s < l:
    UK[:, t, s] -= beta[:, t] * c * UK[:, l, s]
    if s == l:
    UK[:, t, s] -= beta[:, t] * einsum('nd,nh->nd', k[:, t], u[:, s])
    UK[:, t, s] -= beta[:, t] * c * UK[:, l, l]
    # [s=t]
    UK[:, t, t] -= beta[:, t] * einsum('nsd,nsh->nd', k[:, :t], u[:, :t])

    UB = u.new_zeros(NH, T, S, D) # d u_t / d beta_s

    for t in range(T):
    for s in range(t):
    for l in range(t):
    c = (k[:, t] * k[:, l]).sum(-1)
    UB[:, t, s] -= beta[:, t] * c * UB[:, l, s]

    UB[:, t, t] = v[:, t] - einsum('ns,nsd->nd', K[:, :t, t], u[:, :t])

    d_beta = einsum('ntsd->ns', UB) # sum T and D out
    d_k = einsum('ntsd->nsd', UK) # sum T out
    d_v = einsum('ntsd->nsd', UV) # sum T out

    return u, d_k, d_v, d_beta


    def simple_backward(d_out, q, k, v, beta):
    u, du_dk, du_dv, du_dbeta = decay_values_backward_only_u(k, v, beta)
    d_q, d_k1, d_u = causal_attend_backward(d_out, q, k, u)

    d_k = d_k1 + d_u * du_dk
    d_v = d_out * d_u * du_dv
    d_beta = (d_out * d_u * du_dbeta.unsqueeze(-1)).sum(-1)

    return d_q, d_k, d_v, d_beta


    class DeltaSimple(torch.autograd.Function):
    @staticmethod
    def forward(ctx, q, k, v, beta):
    ctx.save_for_backward(q, k, v, beta)
    w, u = decay_values(k, v, beta)
    return causal_attend(q, k, u)

    @staticmethod
    def backward(ctx, d_out):
    q, k, v, beta = ctx.saved_tensors
    return simple_backward(d_out, q, k, v, beta)


    def test_delta_simple_backward():
    NH, T, D = 1, 16, 2
    q1, k1, v1, beta1 = make_example(NH, T, D)
    y1 = forward_simple(q1, k1, v1, beta1)
    (y1 - torch.ones_like(y1).detach()).pow(2).mean().backward()

    q, k, v, beta = make_example(NH, T, D)
    y = DeltaSimple.apply(q, k, v, beta)
    (y - torch.ones_like(y).detach()).pow(2).mean().backward()

    # print(beta1.grad - beta.grad, 'beta.grad diff')
    # print(k1.grad - k.grad, 'k.grad diff')
    # print(v1.grad - v.grad, 'v.grad diff')

    assert allclose(y1, y, atol=1e-5), 'y is wrong'

    test_delta_simple_backward()
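
UV, UK and UB above are dense Jacobians of u with respect to v, k and beta. When debugging tensors shaped like these, torch.autograd.functional.jacobian gives a reference to compare against; a sketch, assuming decay_values_forward is in scope:

import torch
from torch.autograd.functional import jacobian

NH, T, D = 1, 3, 2
k = torch.randn(NH, T, D, dtype=torch.float64)
v = torch.randn(NH, T, D, dtype=torch.float64)
beta = torch.rand(NH, T, dtype=torch.float64)

# reference d u / d v, shaped (NH, T, D, NH, T, D)
UV_ref = jacobian(lambda v_: decay_values_forward(k, v_, beta)[1], v)
print(UV_ref.shape)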
  26. proger revised this gist Jul 1, 2024. 1 changed file with 50 additions and 24 deletions.
    74 changes: 50 additions & 24 deletions deltanet.py
    @@ -159,14 +159,23 @@ def shape(q, k, v, beta=None):
    return NH, T, D


    def test_equal(atol=1e-6):
    def make_example(NH, T, D):
    manual_seed(0)
    NH, T, D = 2*3, 128, 16
    #NH, T, D = 1, 8, 3
    q = randn(NH, T, D) / D**0.5
    q.requires_grad_()
    k = randn(NH, T, D) / D**0.5
    k.requires_grad_()
    v = randn(NH, T, D) / D**0.5
    v.requires_grad_()
    beta = randn(NH, T).sigmoid()
    beta.requires_grad_()
    return q, k, v, beta


    def test_equal(atol=1e-6):
    NH, T, D = 2*3, 128, 16
    #NH, T, D = 1, 8, 3
    q, k, v, beta = make_example(NH, T, D)

    y1 = forward_ogloop(q, k, v, beta)
    y2 = forward_scanloop(q, k, v, beta)
    @@ -202,16 +211,21 @@ def causal_attend_backward(d_out, q, k, v, diagonal=0):
    return d_q, d_k, d_v


    def test_equal_attend_backward(atol=1e-5):
    manual_seed(0)
    class CausalAttend(torch.autograd.Function):
    @staticmethod
    def forward(ctx, q, k, v):
    ctx.save_for_backward(q, k, v)
    return causal_attend(q, k, v)

    @staticmethod
    def backward(ctx, d_out):
    q, k, v = ctx.saved_tensors
    return causal_attend_backward(d_out, q, k, v)


    def test_equal_attend_backward(atol=1e-5):
    NH, T, D = 1*1, 512, 64
    q = randn(NH, T, D) / D**0.5 / D**0.25
    q.requires_grad_()
    k = (randn(NH, T, D) / D**0.5 / D**0.25).sigmoid()
    k.requires_grad_()
    v = randn(NH, T, D) / D**0.5
    v.requires_grad_()
    q, k, v, beta = make_example(NH, T, D)

    y = causal_attend(q, k, v)
    d_q, d_k, d_v = causal_attend_backward(torch.ones_like(y), q, k, v)
    @@ -227,7 +241,32 @@ def test_equal_attend_backward(atol=1e-5):
    # assert torch.allclose(g_hook.grad, d_g, atol=1e-1), 'g.grad is wrong'


    def test_equal_attend_backward2(atol=1e-5):
    NH, T, D = 1, 2, 2
    q1, k1, v1, beta1 = make_example(NH, T, D)
    y1 = causal_attend(q1, k1, v1)
    (y1 - torch.ones_like(y1)).pow(2).mean().backward()

    q, k, v, beta = make_example(NH, T, D)
    y = CausalAttend.apply(q, k, v)
    (y - torch.ones_like(y)).pow(2).mean().backward()

    # print(q1.grad - q.grad, 'q.grad diff')
    # print(k1.grad - k.grad, 'k.grad diff')
    # print(v1.grad - v.grad, 'v.grad diff')
    # print(k.grad, 'k.grad')
    # print(k1.grad, 'k1.grad')
    # print(v.grad, 'v.grad')
    # print(v1.grad, 'v1.grad')

    assert (q1.grad - q.grad).abs().max() < atol, 'q.grad is wrong'
    assert (k1.grad - k.grad).abs().max() < atol, 'k.grad is wrong'
    assert (v1.grad - v.grad).abs().max() < atol, 'v.grad is wrong'


    test_equal_attend_backward()
    test_equal_attend_backward2()


    #%%

    @@ -372,19 +411,6 @@ def decay_values_backward(k, v, beta):
    return w, u, d_k, d_v, d_beta


    def make_example(NH, T, D):
    manual_seed(0)
    q = randn(NH, T, D) / D**0.5
    q.requires_grad_()
    k = randn(NH, T, D) / D**0.5
    k.requires_grad_()
    v = randn(NH, T, D) / D**0.5
    v.requires_grad_()
    beta = randn(NH, T).sigmoid()
    beta.requires_grad_()
    return q, k, v, beta


    def test_equal_decay_values_backward():
    NH, T, D = 1, 4, 2

  27. proger revised this gist Jul 1, 2024. 1 changed file with 31 additions and 24 deletions.
    55 changes: 31 additions & 24 deletions deltanet.py
    @@ -34,6 +34,7 @@

    import os
    os.environ['TORCH_LOGS'] = 'output_code'
    import torch
    from torch import einsum, randn, allclose, stack, eye, manual_seed, no_grad, set_float32_matmul_precision, compile, arange

    set_float32_matmul_precision('high')
    @@ -49,8 +50,8 @@ def decay_values(k, v, beta):
    K = einsum('nsd,ntd->nst', k, k) # (T,T) matrix

    for t in range(1,T):
    w[:, t] -= beta_[:, t] * einsum('nt,ntd->nd', K[:, :t, t], w[:, :t])
    u[:, t] -= beta_[:, t] * einsum('nt,ntd->nd', K[:, :t, t], u[:, :t])
    w[:, t] -= beta_[:, t] * einsum('nt,ntd->nd', K[:, :t, t], w[:, :t].clone())
    u[:, t] -= beta_[:, t] * einsum('nt,ntd->nd', K[:, :t, t], u[:, :t].clone())

    return w, u

    @@ -185,19 +186,19 @@ def test_equal(atol=1e-6):


    @no_grad()
    def attend_backward(q, k, v, g):
    d_q = einsum('ntk,ntv,nst->nsk', k, v, g)
    d_k = einsum('nsk,ntv,nst->ntk', q, v, g)
    d_v = einsum('nsk,ntk,nst->nt', q, k, g).unsqueeze(-1).expand_as(v)
    d_g = einsum('nsk,ntk,ntv,nst->nst', q, k, v, g)
    def attend_backward(d_out, q, k, v, g):
    d_q = einsum('nsv,ntk,ntv,nst->nsk', d_out, k, v, g)
    d_k = einsum('nsv,nsk,ntv,nst->ntk', d_out, q, v, g)
    d_v = einsum('nsv,nsk,ntk,nst->ntv', d_out, q, k, g)
    d_g = einsum('nsv,nsk,ntk,ntv,nst->nst', d_out, q, k, v, g)
    return d_q, d_k, d_v, d_g


    @no_grad()
    def causal_attend_backward(q, k, v, diagonal=0):
    def causal_attend_backward(d_out, q, k, v, diagonal=0):
    NH, T, D = shape(q, k, v)
    mask = q.new_ones(T, T).tril(diagonal=diagonal).unsqueeze(0)
    d_q, d_k, d_v, _d_mask = attend_backward(q, k, v, mask)
    d_q, d_k, d_v, _d_mask = attend_backward(d_out, q, k, v, mask)
    return d_q, d_k, d_v


    @@ -213,7 +214,7 @@ def test_equal_attend_backward(atol=1e-5):
    v.requires_grad_()

    y = causal_attend(q, k, v)
    d_q, d_k, d_v = causal_attend_backward(q, k, v)
    d_q, d_k, d_v = causal_attend_backward(torch.ones_like(y), q, k, v)
    y.sum().backward()

    assert allclose(q.grad, d_q, atol=atol), 'q.grad is wrong'
    @@ -353,7 +354,7 @@ def decay_values_backward(k, v, beta):
    # [s=t]
    UK[:, t, t] -= beta[:, t] * einsum('nsd,nsh->nd', k[:, :t], u[:, :t])

    UB = w.new_zeros(NH, T, S, D) # d u_t / d beta_s
    UB = u.new_zeros(NH, T, S, D) # d u_t / d beta_s

    for t in range(T):
    for s in range(t):
    @@ -368,36 +369,42 @@ def decay_values_backward(k, v, beta):

    d_v = einsum('ntsd->nsd', UV) # sum T out

    return d_k, d_v, d_beta
    return w, u, d_k, d_v, d_beta


    def make_example(NH, T, D):
    manual_seed(0)
    q = randn(NH, T, D) / D**0.5
    q.requires_grad_()
    k = randn(NH, T, D) / D**0.5
    k.requires_grad_()
    v = randn(NH, T, D) / D**0.5
    v.requires_grad_()
    beta = randn(NH, T).sigmoid()
    beta.requires_grad_()
    return q, k, v, beta


    def test_equal_decay_values_backward():
    NH, T, D = 1, 4, 2

    def make():
    manual_seed(0)
    k = randn(NH, T, D) / D**0.5
    k.requires_grad_()
    v = randn(NH, T, D) / D**0.5
    v.requires_grad_()
    beta = randn(NH, T).sigmoid()
    beta.requires_grad_()
    return k, v, beta
    q, k, v, beta = make_example(NH, T, D)

    ## to try these modify summations in d_beta and d_k above
    # k, v, beta = make()
    # w, u = decay_values_forward(k, v, beta)
    # w.sum().backward()
    # k_grad, v_grad, beta_grad = decay_values_backward(k, v, beta)
    # w, u, k_grad, v_grad, beta_grad = decay_values_backward(k, v, beta)
    # # print(k.grad, 'k.grad')
    # # print(k_grad, 'k_grad')
    # # print(k.grad - k_grad, 'diff')
    # assert allclose(k.grad, k_grad, atol=1e-5), 'k_grad is wrong'
    # assert allclose(beta.grad, beta_grad, atol=1e-5), 'beta_grad is wrong'

    k, v, beta = make()
    q, k, v, beta = make_example(NH, T, D)
    w, u = decay_values_forward(k, v, beta)
    (w + u).sum().backward()
    k_grad, v_grad, beta_grad = decay_values_backward(k, v, beta)
    w, u, k_grad, v_grad, beta_grad = decay_values_backward(k, v, beta)
    # print(v.grad, 'v.grad', v.grad.shape)
    # print(v_grad, 'v_grad')
    # print(v.grad - v_grad, 'v diff')
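
In matrix form, the attend_backward rewrite in this revision is the standard vector-Jacobian product of gated attention:

# y   = (Q K^T * G) V                  where * gates elementwise by G (here the causal mask)
# d_Q = ((d_y V^T) * G) K
# d_K = ((d_y V^T) * G)^T Q
# d_V = ((Q K^T) * G)^T d_y
# a mask gradient d_G is also computed, but causal_attend_backward discards it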
  28. proger revised this gist Jun 30, 2024. 1 changed file with 67 additions and 20 deletions.
    87 changes: 67 additions & 20 deletions deltanet.py
    @@ -289,16 +289,15 @@ def decay_values_backward(k, v, beta):
    """

    S = T

    WB = w.new_zeros(NH, T, S, D) # d w_t / d beta_s
    UB = w.new_zeros(NH, T, S, D) # d u_t / d beta_s # TODO: UB
    WK = w.new_zeros(NH, T, S, D) # d w_t / d k_s
    UV = u.new_zeros(NH, T, S, D) # d u_t / d v_s
    UK = u.new_zeros(NH, T, S, D) # d u_t / d k_s # TODO: UK

    for t in range(T):
    WB[:, t, :t] = einsum('n,nt,ntsd->nsd', -beta[:, t], K[:, :t, t], WB[:, :t, :t])
    WB[:, t, t] = k[:, t] - einsum('nt,ntd->nd', K[:, :t, t], w[:, :t])

    WK = w.new_zeros(NH, T, S, D) # d w_t / d k_s

    for t in range(T):
    # [s<t]
    for s in range(t):
    @@ -320,15 +319,53 @@ def decay_values_backward(k, v, beta):
    + [s<t] - b_t \sum_{j=0}^{t-1} k_j^T k_t (d u_j / d v_s)
    """

    UV = u.new_zeros(NH, T, S, D) # d u_t / d v_s

    for t in range(T):
    # [s<t]
    UV[:, t, :t] = einsum('n,nt,ntsd->nsd', -beta[:, t], K[:, :t, t], UV[:, :t, :t])

    # [s=t]
    UV[:, t, t] = beta[:, t]

    d_beta = einsum('ntsd->ns', WB) # sum T and D out
    d_k = einsum('ntsd->nsd', WK) # sum T out
    """
    d u_t / d k_s =
    - b_t \sum_{l=0}^{t-1} (D_{k_s} k_l^T k_t u_l)
    D_{k_s} k_l^T k_t u_l = [s=t] k_l u_l^T # sum out dimensions of u_j after the outer product?
    D_{k_s} k_l^T k_t u_l = [s<t][s<l]: only u_l is a function of k_s
    [s<t][s=l]: product rule for u_l(k_s) and k_l=k_s
    """

    UK = u.new_zeros(NH, T, S, D) # d u_t / d k_s

    for t in range(T):
    # [s<t]
    for s in range(t):
    for l in range(t):
    c = (k[:, t] * k[:, l]).sum(-1)
    if s < l:
    UK[:, t, s] -= beta[:, t] * c * UK[:, l, s]
    if s == l:
    UK[:, t, s] -= beta[:, t] * einsum('nd,nh->nd', k[:, t], u[:, s])
    UK[:, t, s] -= beta[:, t] * c * UK[:, l, l]

    # [s=t]
    UK[:, t, t] -= beta[:, t] * einsum('nsd,nsh->nd', k[:, :t], u[:, :t])

    UB = w.new_zeros(NH, T, S, D) # d u_t / d beta_s

    for t in range(T):
    for s in range(t):
    for l in range(t):
    c = (k[:, t] * k[:, l]).sum(-1)
    UB[:, t, s] -= beta[:, t] * c * UB[:, l, s]

    UB[:, t, t] = v[:, t] - einsum('ns,nsd->nd', K[:, :t, t], u[:, :t])

    d_beta = einsum('ntsd->ns', WB) + einsum('ntsd->ns', UB) # sum T and D out
    d_k = einsum('ntsd->nsd', WK) + einsum('ntsd->nsd', UK) # sum T out

    d_v = einsum('ntsd->nsd', UV) # sum T out

    return d_k, d_v, d_beta
    @@ -346,23 +383,33 @@ def make():
    beta.requires_grad_()
    return k, v, beta

    k, v, beta = make()
    w, u = decay_values_forward(k, v, beta)
    w.sum().backward()
    k_grad, v_grad, beta_grad = decay_values_backward(k, v, beta)
    # print(k.grad, 'k.grad')
    # print(k_grad, 'k_grad')
    # print(k.grad - k_grad, 'diff')
    assert allclose(k.grad, k_grad, atol=1e-5), 'k_grad is wrong'
    assert allclose(beta.grad, beta_grad, atol=1e-5), 'beta_grad is wrong'
    ## to try these modify summations in d_beta and d_k above
    # k, v, beta = make()
    # w, u = decay_values_forward(k, v, beta)
    # w.sum().backward()
    # k_grad, v_grad, beta_grad = decay_values_backward(k, v, beta)
    # # print(k.grad, 'k.grad')
    # # print(k_grad, 'k_grad')
    # # print(k.grad - k_grad, 'diff')
    # assert allclose(k.grad, k_grad, atol=1e-5), 'k_grad is wrong'
    # assert allclose(beta.grad, beta_grad, atol=1e-5), 'beta_grad is wrong'

    k, v, beta = make()
    w, u = decay_values_forward(k, v, beta)
    u.sum().backward()
    (w + u).sum().backward()
    k_grad, v_grad, beta_grad = decay_values_backward(k, v, beta)
    print(v.grad, 'v.grad', v.grad.shape)
    print(v_grad, 'v_grad')
    print(v.grad - v_grad, 'v diff')
    # print(v.grad, 'v.grad', v.grad.shape)
    # print(v_grad, 'v_grad')
    # print(v.grad - v_grad, 'v diff')
    assert allclose(v.grad, v_grad, atol=1e-5), 'v_grad is wrong'

    test_equal_decay_values_backward()
    # print(k.grad, 'k.grad du')
    # print(k_grad, 'k_grad du')
    # print(k.grad - k_grad, 'diff du')
    assert allclose(k.grad, k_grad, atol=1e-5), 'k_grad wrt u is wrong'

    # print(beta.grad, 'beta.grad du')
    # print(beta_grad, 'beta_grad du')
    assert allclose(beta.grad, beta_grad, atol=1e-5), 'beta_grad wrt u is wrong'

    test_equal_decay_values_backward()
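
The case analysis for d u_t / d k_s above can be spot-checked one (t, s) pair at a time with central differences; a sketch, assuming decay_values_forward is in scope:

import torch

def fd_du_dk(t, s, NH=1, T=3, D=2, eps=1e-6):
    # central-difference probe of the (t, s) block of d u / d k
    torch.manual_seed(0)
    k = torch.randn(NH, T, D, dtype=torch.float64)
    v = torch.randn(NH, T, D, dtype=torch.float64)
    beta = torch.rand(NH, T, dtype=torch.float64)

    def u_t(k_):
        return decay_values_forward(k_, v, beta)[1][:, t]

    cols = []
    for d in range(D):
        k_hi, k_lo = k.clone(), k.clone()
        k_hi[:, s, d] += eps
        k_lo[:, s, d] -= eps
        cols.append((u_t(k_hi) - u_t(k_lo)) / (2 * eps))
    return torch.stack(cols, dim=1)  # (NH, D_k, D_v) block of the Jacobian

print(fd_du_dk(t=2, s=0))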
  29. proger revised this gist Jun 30, 2024. 1 changed file with 25 additions and 9 deletions.
    34 changes: 25 additions & 9 deletions deltanet.py
    @@ -261,6 +261,8 @@ def decay_values_backward(k, v, beta):
    w_1 = b_1 k_1 - b_1 k_0^T k_1 w_0
    w_2 = b_2 k_2 - b_2 k_0^T k_2 w_0 - b_2 k_1^T k_2 w_1
    w_t = b_t k_t - b_t \sum_{s=0}^{t-1} k_s^T k_t w_s
    u_t = b_t v_t - b_t \sum_{s=0}^{t-1} k_s^T k_t u_s
    """
    w[:, 0] = beta[:, 0] * k[:, 0]
    u[:, 0] = beta[:, 0] * v[:, 0]
    @@ -311,7 +313,21 @@ def decay_values_backward(k, v, beta):
    WK[:, t, t] = beta[:, t]
    WK[:, t, t] += -beta[:, t] * einsum('nsj,nsi->nj', k[:, :t], w[:, :t])

    d_beta = WB.sum(1).sum(-1) # sum T out and then D # TODO: d u_t / d beta_s
    """
    u_t = b_t v_t - b_t \sum_{s=0}^{t-1} k_s^T k_t u_s
    d u_t / d v_s = [s=t] b_t
    + [s<t] - b_t \sum_{j=0}^{t-1} k_j^T k_t (d u_j / d v_s)
    """

    for t in range(T):
    # [s<t]
    UV[:, t, :t] = einsum('n,nt,ntsd->nsd', -beta[:, t], K[:, :t, t], UV[:, :t, :t])

    # [s=t]
    UV[:, t, t] = beta[:, t]

    d_beta = einsum('ntsd->ns', WB) # sum T and D out
    d_k = einsum('ntsd->nsd', WK) # sum T out
    d_v = einsum('ntsd->nsd', UV) # sum T out

    @@ -340,13 +356,13 @@ def make():
    assert allclose(k.grad, k_grad, atol=1e-5), 'k_grad is wrong'
    assert allclose(beta.grad, beta_grad, atol=1e-5), 'beta_grad is wrong'

    # k, v, beta = make()
    # w, u = decay_values_forward(k, v, beta)
    # u.sum().backward()
    # k_grad, v_grad, beta_grad = decay_values_backward(k, v, beta)
    # print(v.grad, 'v.grad', v.grad.shape)
    # print(v_grad, 'v_grad')
    # print(v.grad - v_grad, 'v diff')
    # assert allclose(v.grad, v_grad, atol=1e-5), 'v_grad is wrong'
    k, v, beta = make()
    w, u = decay_values_forward(k, v, beta)
    u.sum().backward()
    k_grad, v_grad, beta_grad = decay_values_backward(k, v, beta)
    print(v.grad, 'v.grad', v.grad.shape)
    print(v_grad, 'v_grad')
    print(v.grad - v_grad, 'v diff')
    assert allclose(v.grad, v_grad, atol=1e-5), 'v_grad is wrong'

    test_equal_decay_values_backward()
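
The UV derivation relies on u being linear in v: the recurrence only ever multiplies v by functions of k and beta, so the whole map is u = L v for a block-lower-triangular L. A quick check of that linearity, assuming decay_values_forward is in scope:

import torch

torch.manual_seed(0)
NH, T, D = 1, 4, 2
k = torch.randn(NH, T, D, dtype=torch.float64)
v = torch.randn(NH, T, D, dtype=torch.float64)
beta = torch.rand(NH, T, dtype=torch.float64)

_, u1 = decay_values_forward(k, v, beta)
_, u2 = decay_values_forward(k, 2 * v, beta)
assert torch.allclose(u2, 2 * u1)  # u scales linearly with v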
  30. proger revised this gist Jun 30, 2024. 1 changed file with 6 additions and 6 deletions.
    12 changes: 6 additions & 6 deletions deltanet.py
    @@ -34,7 +34,7 @@

    import os
    os.environ['TORCH_LOGS'] = 'output_code'
    from torch import einsum, randn, allclose, stack, eye, manual_seed, no_grad, set_float32_matmul_precision, compile
    from torch import einsum, randn, allclose, stack, eye, manual_seed, no_grad, set_float32_matmul_precision, compile, arange

    set_float32_matmul_precision('high')

    @@ -300,12 +300,12 @@ def decay_values_backward(k, v, beta):
    for t in range(T):
    # [s<t]
    for s in range(t):
    # for j in range(t)
    # [s<j]
    # for j in range(t): [s<j]
    WK[:, t, s] -= beta[:, t] * einsum('nj,njd->nd', K[:, s+1:t, t], WK[:, s+1:t, s])
    # [s=j]
    WK[:, t, s] -= beta[:, t] * einsum('nh,nd->nh', k[:, t], w[:, s])
    WK[:, t, s] -= beta[:, t] * einsum('n,nd->nd', K[:, s, t], WK[:, s, s])

    # for s in range(t): for j in range(t): [s<t][s=j]
    WK[:, t, :t] -= beta[:, t] * einsum('nh,nsd->nsh', k[:, t], w[:, :t])
    WK[:, t, :t] -= beta[:, t] * einsum('ns,nsd->nsd', K[:, :t, t], WK[:, arange(t), arange(t)])

    # [s=t]
    WK[:, t, t] = beta[:, t]
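
The arange indexing introduced here replaces the per-s inner loop by gathering the diagonal entries WK[:, s, s] for all s at once. A tiny illustration of the indexing pattern (standalone, not from the gist):

import torch
from torch import arange

X = torch.randn(1, 3, 3, 4)
diag = X[:, arange(3), arange(3)]   # (1, 3, 4): X[:, s, s] for every s in one gather
assert torch.equal(diag[0, 1], X[0, 1, 1])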