Last active: October 8, 2024 04:13
Revisions
proger revised this gist on Jul 16, 2024: 1 changed file with 1 addition and 1 deletion.
```diff
@@ -452,7 +452,7 @@ def backward(ctx, d_y):
 def test_delta():
-    NH, T, D = 2, 16, 3
+    NH, T, D = 1, 64, 16
     q1, k1, v1, beta1 = make_example(NH, T, D)
     y0 = forward_loop(q1, k1, v1, beta1)
```
proger revised this gist on Jul 16, 2024: 1 changed file with 150 additions and 121 deletions.
```python
# @@ -1,13 +1,8 @@
"""
DeltaNet implementation reference for Accelerated Scan.

DeltaNet performs efficient management of a large fixed-sized memory.

`forward` is inspired by Yang 2024. It applies single chunk version pointwise
and then performs chunk-level stitching.

`forward_loop` is the reference implementation of the original recurrence.

References:
[1] The WY Representation for Products of Householder Matrices (Bischof and Van Loan 1985)

# @@ -63,78 +58,82 @@ def fmt(r,c,tag):
        ]))

def decay_values(q, k, v, beta, chunk_size=2):
    NH, T, D = shape(q, k, v, beta)
    C = T // chunk_size
    q_, k_, v_, beta_ = (
        q.view(NH*C, chunk_size, D),
        k.view(NH*C, chunk_size, D),
        v.view(NH*C, chunk_size, D),
        beta.view(NH*C, chunk_size)
    )

    # evaluate all chunks in parallel
    beta__ = beta_.unsqueeze(-1)
    w = beta__ * k_.clone()
    u = beta__ * v_.clone()
    K = einsum('nsd,ntd->nst', k_, k_)  # (chunk_size,chunk_size) matrix
    for t in range(1, chunk_size):
        w[:, t] -= beta__[:, t] * einsum('nt,ntd->nd', K[:, :t, t], w[:, :t].clone())
        u[:, t] -= beta__[:, t] * einsum('nt,ntd->nd', K[:, :t, t], u[:, :t].clone())

    # attend to decayed values
    qk = einsum("nsk,ntk->nst", q_, k_)
    qk.tril_()
    y = einsum("nst,ntj->nsj", qk, u)
    return w, u, y

def forward(q, k, v, beta, chunk_size=2):
    "decay values applying deltanet forgetting rules, then stitch chunks"
    NH, T, D = shape(q, k, v, beta)
    C = T // chunk_size
    w, u, y = decay_values(q, k, v, beta, chunk_size=chunk_size)

    # stitch chunks sequentially
    q_ = q.view(NH, C, chunk_size, D)
    k_ = k.view(NH, C, chunk_size, D)
    u = u.view(NH, C, chunk_size, D)
    w = w.view(NH, C, chunk_size, D)
    y = y.view(NH, C, chunk_size, D)

    # materialize the state for the leading chunk
    kc = k_[:, 0]
    uc = u[:, 0]
    state = u.new_zeros(NH, D, D)
    for c in range(1, C):
        state = state + einsum('ntv,ntk->nvk', uc, kc)
        wc = w[:, c]  # load w
        uc = einsum('ntk,nvk->ntv', wc, state)  # DDT
        qc = q_[:, c]  # load q
        kc = k_[:, c]  # load k

        # attend to old values
        qk = einsum("nsi,nti->nst", qc, kc)  # TDT
        qk = qk.tril()
        yc = y[:, c].clone()  # load y
        y_prev = einsum("nst,ntv->nsv", qk, uc)  # TTD
        yc = yc - y_prev
        y_cur = einsum('nsk,nvk->nsv', qc, state)  # DDT
        yc = yc + y_cur
        y[:, c] = yc  # store
        u1 = u[:, c]  # load u
        uc = u1 - uc

    w = w.view(NH, T, D)
    u = u.view(NH, T, D)
    y = y.view(NH, T, D)
    return w, u, y

# @@ -186,8 +185,15 @@ def make_example(NH, T, D, device='cpu', dtype=torch.float32):
@no_grad()
def backward(d_out_w_long, d_out_u_long, d_out_y_long, q_long, k_long, v_long, beta_long, chunk_size=2):
    NH, T, D = shape(q_long, k_long, v_long, beta_long)
    C = T // chunk_size
    q, k, v, beta, d_out_y = (
        q_long.view(NH*C, chunk_size, D),
        k_long.view(NH*C, chunk_size, D),
        v_long.view(NH*C, chunk_size, D),
        beta_long.view(NH*C, chunk_size),
        d_out_y_long.view(NH*C, chunk_size, D)
    )

    #
    # allocations

# @@ -198,21 +204,21 @@ def decay_values_backward(d_out_w, d_out_u, d_out_y, q, k, v, beta):
    k = k.clone()  # load k
    v = v.clone()  # load v
    beta = beta.clone()  # load beta
    #d_out_w = d_out_w.clone() # ntk # placeholders
    #d_out_y = d_out_y.clone() # ntv # placeholders
    w = k.new_zeros(NH*C, chunk_size, D)  # ntk
    u = v.new_zeros(NH*C, chunk_size, D)  # ntw
    w_bases = w.clone()  # ntk
    u_bases = u.clone()  # ntw
    bk = einsum('nt,ntk->ntk', beta, k)
    bKl = k.new_zeros(NH*C, chunk_size, chunk_size)
    tt = k.new_zeros(NH*C, chunk_size, chunk_size)
    d_k = k.new_zeros(NH*C, chunk_size, D)  # nsk
    tk = k.new_zeros(NH*C, chunk_size, D)  # ntk

    #
    # forward

# @@ -225,7 +231,7 @@ def decay_values_backward(d_out_w, d_out_u, d_out_y, q, k, v, beta):
    u_bases = v
    v = einsum('nt,ntw->ntw', beta, v)
    for t in range(chunk_size):
        tk = einsum('nts,nsk->ntk', bKl, w)  # matmul for the sake of one row
        w[:, t] = bk[:, t, :] - tk[:, t, :]
        tk = einsum('nts,nsw->ntw', bKl, u)  # matmul for the sake of one row

# @@ -234,6 +240,17 @@
    w.clone()  # store w
    u.clone()  # store u

    #
    # stitch_backward
    #
    w_long = w.view(NH, T, D)
    u_long = u.view(NH, T, D)
    d_q_1_long, d_k_1_long, d_out_w_long, d_out_u_long = stitch_backward(d_out_y_long, q_long, k_long, w_long, u_long, C, chunk_size)
    d_out_w, d_out_u = (
        d_out_w_long.view(NH*C, chunk_size, D),
        d_out_u_long.view(NH*C, chunk_size, D)
    )
    w_bases = einsum('nts,nsk->ntk', tt, w)
    w_bases = k - w_bases
    v = einsum('nts,nsw->ntw', tt, u)

# @@ -264,7 +281,7 @@ def decay_values_backward(d_out_w, d_out_u, d_out_y, q, k, v, beta):
    d_k.zero_()
    for t in range(chunk_size-1, -1, -1):
        # d_k
        tt = einsum('njw,ntw->njt', w, d_out_w)  # matmul for the sake of one column t
        tt[:, t:, :] = 0

# @@ -311,132 +328,143 @@ def decay_values_backward(d_out_w, d_out_u, d_out_y, q, k, v, beta):
    beta += einsum('ntv->nt', u_bases)
    d_beta = beta.clone()  # store
    d_q_long = d_q.view(NH, T, D) + d_q_1_long
    d_k_long = d_k.view(NH, T, D) + d_k_1_long
    d_v_long = d_v.view(NH, T, D)
    d_beta_long = d_beta.view(NH, T)
    return d_q_long, d_k_long, d_v_long, d_beta_long

def stitch_backward(d_y_delta, q, k, w, u, C, chunk_size):
    NH, T, D = shape(q, k, None, None)

    # outputs
    d_q_ = q.new_zeros(NH, C, chunk_size, D)
    d_k_ = k.new_zeros(NH, C, chunk_size, D)
    d_w = w.new_zeros(NH, C, chunk_size, D)
    d_u = u.new_zeros(NH, C, chunk_size, D)

    # chunked inputs
    d_y_delta = d_y_delta.view(NH, C, chunk_size, D)
    q_ = q.view(NH, C, chunk_size, D)
    k_ = k.view(NH, C, chunk_size, D)
    w = w.view(NH, C, chunk_size, D)  # shared memory copy
    u = u.view(NH, C, chunk_size, D).clone()

    state = w.new_zeros(NH, D, D)
    d_state = w.new_zeros(NH, D, D)  # NHVK
    state_delta = w.new_zeros(NH, D, D)  # NHVK

    # can this be float32?
    qk = k.new_zeros(NH, chunk_size, C)
    tk = k.new_zeros(NH, chunk_size, D)

    # materialize the state for the leading chunk
    state = einsum('ntv,ntk->nvk', u[:, 0], k_[:, 0])

    # stitch forward
    for c in range(1, C):
        tk = einsum('nvk,ntk->ntv', state, w[:, c])
        u[:, c] = u[:, c] - tk
        state_delta = einsum('ntv,ntk->nvk', u[:, c], k_[:, c])
        if c < C-1:
            state = state + state_delta  # walk the state forwards

    # from now on, u's are decayed

    # stitch backward
    for c in range(C-1, 0, -1):
        if c < C-1:
            state_delta = einsum('ntv,ntk->nvk', u[:, c], k_[:, c])
            state = state - state_delta  # uncompute the state backwards

        tk = einsum('nvk,ntk->ntv', state, w[:, c])  # state_decay
        d_y_delta_c = d_y_delta[:, c]
        d_y_delta_c = -d_y_delta_c  # neg

        # d_q, d_k
        qk = einsum('nsv,ntv->nst', d_y_delta_c, tk)
        qk.tril_()

        # d_q
        tk = einsum('nst,ntk->nsk', qk, k_[:, c])  # causal_attend_backward for delta
        tk.sub_(einsum('nsv,nvk->nsk', d_y_delta_c, state))  # prev_output
        d_q_[:, c] = tk

        # d_k
        tk = einsum('nst,nsk->ntk', qk, q_[:, c])
        if c < C-1:
            tk.add_(einsum('nvk,ntv->ntk', d_state, u[:, c]))  # state_add
        else:
            # d_state is zero
            pass
        d_k_[:, c] = tk

        # d_u
        if c < C-1:
            d_u[:, c] = einsum('nvk,ntk->ntv', d_state, k_[:, c])  # state_add
        else:
            # d_state is zero
            pass

        # d_state_decays
        qk = einsum('nsk,ntk->nst', q_[:, c], k_[:, c])
        qk.tril_()
        d_state_decays = einsum('nsv,nst->ntv', d_y_delta_c, qk)
        if c < C-1:
            d_state_decays.sub_(einsum('nvk,ntk->ntv', d_state, k_[:, c]))  # state_add

        # d_w
        tk = einsum('ntv,nvk->ntk', d_state_decays, state)
        d_w[:, c] = tk  # state_decays

        # backpropagate through time
        d_state.sub_(einsum('nsv,nsk->nvk', d_y_delta_c, q_[:, c]))  # prev_output
        d_state.add_(einsum('ntv,ntk->nvk', d_state_decays, w[:, c]))  # state_decays

    tk = einsum('nvk,ntk->ntv', d_state, k_[:, 0])
    d_u[:, 0] = tk  # state_add
    tk = einsum('nvk,ntv->ntk', d_state, u[:, 0])
    d_k_[:, 0] = tk  # state_add
    return d_q_.view(NH, T, D), d_k_.view(NH, T, D), d_w.view(NH, T, D), d_u.view(NH, T, D)

class Delta(torch.autograd.Function):
    @staticmethod
    def forward(ctx, q, k, v, beta, chunk_size):
        w, u, y = forward(q, k, v, beta, chunk_size)
        ctx.save_for_backward(q, k, v, beta)
        ctx.chunk_size = chunk_size
        return y

    @staticmethod
    def backward(ctx, d_y):
        q, k, v, beta = ctx.saved_tensors
        NH, T, D = shape(q, k, v, beta)
        d_w = k.new_zeros(NH, T, D)
        d_u = v.new_zeros(NH, T, D)
        d_q, d_k, d_v, d_beta = backward(d_w, d_u, d_y, q, k, v, beta, chunk_size=ctx.chunk_size)
        return d_q, d_k, d_v, d_beta, None

def test_delta():
    NH, T, D = 2, 16, 3
    q1, k1, v1, beta1 = make_example(NH, T, D)
    y0 = forward_loop(q1, k1, v1, beta1)
    chunk_size = 8
    w1, u1, y1 = forward(q1, k1, v1, beta1, chunk_size=chunk_size)
    (y1 - torch.ones_like(y1).detach()).pow(2).mean().backward()
    assert allclose(y0, y1, atol=1e-5), 'y1 is wrong'

    q, k, v, beta = make_example(NH, T, D)
    y = Delta.apply(q, k, v, beta, chunk_size)
    (y - torch.ones_like(y).detach()).pow(2).mean().backward()
    assert allclose(y1, y, atol=1e-5), 'y is wrong'

# @@ -452,4 +480,5 @@ def test_delta_chunkwise_backward():
    assert allclose(v1.grad, v.grad, atol=1e-5), 'v.grad is wrong'

if __name__ == '__main__':
    test_delta()
```
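The docstring above calls `forward_loop` the reference implementation of the original recurrence, which an earlier revision's docstring states as `w_t = w_{t-1} + beta_t (v_t - w_t k_t) k_t`. For orientation, here is a minimal self-contained sketch of that per-step delta rule; `delta_rule_reference` and its variable names are my illustration, not the gist's exact code, though the `(NH, T, D)` shapes follow the gist:

```python
import torch
from torch import einsum

def delta_rule_reference(q, k, v, beta):
    # one (D, D) fast-weight memory per head, updated by the delta rule
    NH, T, D = q.shape
    state = q.new_zeros(NH, D, D)
    ys = []
    for t in range(T):
        k_t, v_t = k[:, t], v[:, t]
        pred = einsum('nvk,nk->nv', state, k_t)           # current readout for k_t
        state = state + einsum('n,nv,nk->nvk', beta[:, t], v_t - pred, k_t)
        ys.append(einsum('nvk,nk->nv', state, q[:, t]))   # query the updated memory
    return torch.stack(ys, dim=1)
```

The chunked `forward` above should agree with a loop of this shape up to numerical tolerance; that is exactly what `test_delta` asserts via `forward_loop`.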
proger revised this gist on Jul 15, 2024: 1 changed file with 26 additions and 24 deletions.
```python
# @@ -336,46 +336,48 @@ def stitch_backward(d_y_delta, q, k, w, u, C, chunk_size):
        state_decay = einsum('nvk,ntk->ntv', state, w[:, c])
        u[:, c] = u[:, c] - state_decay
        state_delta = einsum('ntv,ntk->nvk', u[:, c], k_[:, c])
        if c < C-1:
            state = state + state_delta

    # from now on, u's are decayed

    # stitch backward
    for c in range(C-1, 0, -1):
        d_u[:, c] = einsum('nvk,ntk->ntv', d_state, k_[:, c])  # state_add
        if c < C-1:
            state_delta = einsum('ntv,ntk->nvk', u[:, c], k_[:, c])
            state = state - state_delta  # uncompute the state
        state_decay = einsum('nvk,ntk->ntv', state, w[:, c])
        d_q1 = einsum('nsv,nvk->nsk', d_y_delta[:, c], state)  # prev_output

        # causal_attend_backward for delta
        d_out_att = -d_y_delta[:, c]
        d_out_state_decays = einsum('nsv,ntv->nst', d_out_att, state_decay)
        d_out_state_decays.tril_()
        d_q2 = einsum('nst,ntk->nsk', d_out_state_decays, k_[:, c])
        d_k1 = einsum('nst,nsk->ntk', d_out_state_decays, q_[:, c])
        d_q_[:, c] = d_q1 + d_q2
        d_k2 = einsum('nvk,ntv->ntk', d_state, u[:, c])  # state_add
        d_k_[:, c] = d_k1 + d_k2
        d_state1 = einsum('nsv,nsk->nvk', d_y_delta[:, c], q_[:, c])  # prev_output
        qk = einsum('nsk,ntk->nst', q_[:, c], k_[:, c])
        qk.tril_()
        d_state_decays1 = einsum('nsv,nst->ntv', d_out_att, qk)
        d_state_decays2 = einsum('nvk,ntk->ntv', -d_state, k_[:, c])  # state_add
        d_state_decays = d_state_decays1 + d_state_decays2
        d_w[:, c] = einsum('ntv,nvk->ntk', d_state_decays, state)  # state_decays
        d_state2 = einsum('ntv,ntk->nvk', d_state_decays, w[:, c])  # state_decays
        d_state = d_state + d_state1 + d_state2

    d_u[:, 0] = einsum('nvk,ntk->ntv', d_state, k_[:, 0])  # state_add
    d_k_[:, 0] = einsum('nvk,ntv->ntk', d_state, u[:, 0])  # state_add
    return d_q_.view(NH, T, D), d_k_.view(NH, T, D), d_w.view(NH, T, D), d_u.view(NH, T, D)
```
proger revised this gist on Jul 15, 2024: 1 changed file with 96 additions and 68 deletions.
```python
# @@ -36,7 +36,31 @@
import torch
from torch import einsum, randn, allclose, stack, eye, manual_seed, no_grad, set_float32_matmul_precision, compile, arange

#set_float32_matmul_precision('high')

def tileprint(K, name='K'):
    "format matches tileprint in tk code so you can diff it"
    assert K.shape == (16, 16)
    for laneid in range(32):
        row_top = laneid // 4
        row_bottom = row_top + 8
        col_left = laneid % 4 * 2
        col_right = col_left + 8
        def fmt(r, c, tag):
            odd = "y" in tag
            if odd:
                # do not print r for odd rows because cuda printf silently runs out of function arguments
                return f"{name}[,{c:02}] {tag}={K[r,c]: .3f}"
            else:
                return f"{name}[{r:02},{c:02}] {tag}={K[r,c]: .3f}"
        print(f"lane={laneid:02}", " ".join([
            " ".join([fmt(row_top, col_left, "0x"), fmt(row_top, col_left+1, "0y")]),
            " ".join([fmt(row_bottom, col_left, "1x"), fmt(row_bottom, col_left+1, "1y")]),
            " ".join([fmt(row_top, col_right, "2x"), fmt(row_top, col_right+1, "2y")]),
            " ".join([fmt(row_bottom, col_right, "3x"), fmt(row_bottom, col_right+1, "3y")])
        ]))

# @@ -53,9 +77,8 @@ def decay_values(q, k, v, beta):
        u[:, t] -= beta_[:, t] * einsum('nt,ntd->nd', K[:, :t, t], u[:, :t].clone())

    # attend to decayed values
    qk = einsum("nsk,ntk->nst", q, k)
    qk.tril_()
    y = einsum("nst,ntj->nsj", qk, u)
    return w, u, y

# @@ -78,18 +101,29 @@ def forward_chunkwise(q, k, v, beta, chunk_size=2):
    u = u.view(NH, C, chunk_size, D)
    w = w.view(NH, C, chunk_size, D)

    # materialize the state for the leading chunk
    state = einsum('ntv,ntk->nvk', u[:, 0], k_[:, 0])
    print(y, 'y')
    y_deltas = [u.new_zeros(NH, chunk_size, D)]
    for c in range(1, C):
        if c == 1:
            # early on this is cheaper than a state query
            kw = einsum('nsk,ntk->nst', k_[:, c-1], w[:, c])  # TDT
            u_old = einsum('nsv,nst->ntv', u[:, c-1], kw)  # TDT
            kq = einsum('nsk,ntk->nst', k_[:, c-1], q_[:, c])  # TDT
            y_cur = einsum('nsv,nst->ntv', u[:, c-1], kq)  # TTD
        else:
            u_old = einsum('nvk,ntk->ntv', state, w[:, c])  # DDT
            y_cur = einsum('nvk,nsk->nsv', state, q_[:, c])  # DDT

        # attend to old values
        qk = einsum("nsi,nti->nst", q_[:, c], k_[:, c])  # TDT
        qk = qk.tril()
        y_prev = einsum("nst,ntv->nsv", qk, u_old)  # TTD
        y_deltas.append(y_cur - y_prev)

# @@ -98,7 +132,11 @@
    y_delta = torch.stack(y_deltas, dim=1)
    w = w.view(NH, T, D)
    u = u.view(NH, T, D)
    y = y.view(NH, T, D) + y_delta.view(NH, T, D)
    return w, u, y

def forward_loop(q, k, v, beta):

# @@ -134,19 +172,18 @@ def shape(q, k, v, beta=None):
    return NH, T, D

def make_example(NH, T, D, device='cpu', dtype=torch.float32):
    manual_seed(0)
    q = randn(NH, T, D, device=device, dtype=dtype) / D**0.5
    q.requires_grad_()
    k = randn(NH, T, D, device=device, dtype=dtype) / D**0.5
    k.requires_grad_()
    v = randn(NH, T, D, device=device, dtype=dtype) / D**0.5
    v.requires_grad_()
    beta = randn(NH, T, device=device, dtype=dtype).sigmoid()
    beta.requires_grad_()
    return q, k, v, beta

@no_grad()
def decay_values_backward(d_out_w, d_out_u, d_out_y, q, k, v, beta):

# @@ -162,6 +199,7 @@ def decay_values_backward(d_out_w, d_out_u, d_out_y, q, k, v, beta):
    v = v.clone()  # load v
    beta = beta.clone()  # load beta
    d_out_w = d_out_w.clone()  # ntk
    d_out_y = d_out_y.clone()  # ntv
    w = k.new_zeros(NH, T, D)  # ntk
    u = v.new_zeros(NH, T, D)  # ntw

# @@ -218,7 +256,7 @@ def decay_values_backward(d_out_w, d_out_u, d_out_y, q, k, v, beta):
    v.zero_()  # reuse register space of v for d_out_u
    d_out_u = d_out_u.clone()  # load ntw
    d_out_u += einsum('nst,nsv->ntv', tt, d_out_y)

    #
    # backward for d_k, d_v, d_beta

# @@ -275,93 +313,84 @@ def decay_values_backward(d_out_w, d_out_u, d_out_y, q, k, v, beta):
    return d_q, d_k, d_v, d_beta

def stitch_backward(d_y_delta, q, k, w, u, C, chunk_size):
    NH, T, D = shape(q, k, None, None)
    d_y_delta = d_y_delta.view(NH, C, chunk_size, D)
    q_ = q.view(NH, C, chunk_size, D)
    k_ = k.view(NH, C, chunk_size, D)
    u = u.view(NH, C, chunk_size, D).clone()
    w = w.view(NH, C, chunk_size, D)
    d_q_ = q.new_zeros(NH, C, chunk_size, D)
    d_k_ = k.new_zeros(NH, C, chunk_size, D)
    d_w = w.new_zeros(NH, C, chunk_size, D)
    d_u = u.new_zeros(NH, C, chunk_size, D)
    d_state = w.new_zeros(NH, D, D)  # NHVK

    # materialize the state for the leading chunk
    state = einsum('ntv,ntk->nvk', u[:, 0], k_[:, 0])
    for c in range(1, C):
        state_decay = einsum('nvk,ntk->ntv', state, w[:, c])
        u[:, c] = u[:, c] - state_decay
        state_delta = einsum('ntv,ntk->nvk', u[:, c], k_[:, c])
        state = state + state_delta

    # from now on, u's are decayed

    # stitch backward
    for c in range(C-1, -1, -1):
        d_u[:, c] = einsum('nvk,ntk->ntv', d_state, k_[:, c])  # state_add
        if c == 0:
            d_k_[:, c] = einsum('nvk,ntv->ntk', d_state, u[:, c])  # state_add
        else:
            state_delta = einsum('ntv,ntk->nvk', u[:, c], k_[:, c])
            state = state - state_delta  # uncompute the state
            state_decay = einsum('nvk,ntk->ntv', state, w[:, c])
            d_q1 = einsum('nsv,nvk->nsk', d_y_delta[:, c], state)  # prev_output

            # causal_attend_backward for delta
            d_out_att = -d_y_delta[:, c]
            d_out_state_decays = einsum('nsv,ntv->nst', d_out_att, state_decay)
            d_out_state_decays.tril_()
            d_q2 = einsum('nst,ntk->nsk', d_out_state_decays, k_[:, c])
            d_k1 = einsum('nst,nsk->ntk', d_out_state_decays, q_[:, c])
            d_q_[:, c] = d_q1 + d_q2
            d_k2 = einsum('nvk,ntv->ntk', d_state, u[:, c])  # state_add
            d_k_[:, c] = d_k1 + d_k2
            d_state1 = einsum('nsv,nsk->nvk', d_y_delta[:, c], q_[:, c])  # prev_output
            qk = einsum('nsk,ntk->nst', q_[:, c], k_[:, c])
            qk.tril_()
            d_state_decays1 = einsum('nsv,nst->ntv', d_out_att, qk)
            d_state_decays2 = einsum('nvk,ntk->ntv', -d_state, k_[:, c])  # state_add
            d_state_decays = d_state_decays1 + d_state_decays2
            d_w[:, c] = einsum('ntv,nvk->ntk', d_state_decays, state)  # state_decays
            d_state2 = einsum('ntv,ntk->nvk', d_state_decays, w[:, c])  # state_decays
            d_state = d_state + d_state1 + d_state2

    return d_q_.view(NH, T, D), d_k_.view(NH, T, D), d_w.view(NH, T, D), d_u.view(NH, T, D)

class DeltaChunkwise(torch.autograd.Function):
    @staticmethod
    def forward(ctx, q, k, v, beta, chunk_size):
        w, u, y = forward_chunkwise(q, k, v, beta, chunk_size)
        ctx.save_for_backward(q, k, v, beta, w, u)
        ctx.chunk_size = chunk_size
        return y

    @staticmethod
    def backward(ctx, d_y):
        q, k, v, beta, w, u = ctx.saved_tensors
        NH, T, D = shape(q, k, v, beta)
        chunk_size = ctx.chunk_size
        C = T // chunk_size

# @@ -371,8 +400,6 @@ def backward(ctx, d_y):
        v.view(NH*C, chunk_size, D),
        beta.view(NH*C, chunk_size)
    )
    d_q_1, d_k_1, d_w, d_u = stitch_backward(d_y, q, k, w, u, C=C, chunk_size=chunk_size)
    d_w = d_w.view(NH*C, chunk_size, D)

# @@ -396,19 +423,20 @@
def test_delta_chunkwise_backward():
    NH, T, D = 2, 16, 3
    q1, k1, v1, beta1 = make_example(NH, T, D)
    y0 = forward_loop(q1, k1, v1, beta1)
    w1, u1, y1 = forward_chunkwise(q1, k1, v1, beta1, chunk_size=2)
    (y1 - torch.ones_like(y1).detach()).pow(2).mean().backward()
    assert allclose(y0, y1, atol=1e-5), 'y1 is wrong'

    q, k, v, beta = make_example(NH, T, D)
    y = DeltaChunkwise.apply(q, k, v, beta, 2)
    (y - torch.ones_like(y).detach()).pow(2).mean().backward()
    assert allclose(y1, y, atol=1e-5), 'y is wrong'
    # print(beta1.grad - beta.grad, 'beta.grad diff')
```
proger revised this gist on Jul 12, 2024: 1 changed file with 116 additions and 344 deletions.
```python
# @@ -2,13 +2,12 @@
DeltaNet implementation reference for Accelerated Scan.

DeltaNet performs efficient management of a large fixed-sized memory.

For a simple single chunk version see `forward_simple`. It computes decayed values
by a little bit of recurrence and then applies linear attention (`decay_values`).

`forward_chunkwise` is inspired by Yang 2024. It applies single chunk version pointwise
and then performs chunk-level stitching.

`forward_loop` is the reference implementation of the original recurrence.

References:
[1] The WY Representation for Products of Householder Matrices (Bischof and Van Loan 1985)

# @@ -40,9 +39,9 @@ set_float32_matmul_precision('high')
def decay_values(q, k, v, beta):
    "decay values applying deltanet forgetting rules"
    NH, T, D = shape(q, k, v, beta)
    beta_ = beta.unsqueeze(-1)
    w = beta_ * k.clone()

# @@ -53,21 +52,13 @@ def decay_values(k, v, beta):
        w[:, t] -= beta_[:, t] * einsum('nt,ntd->nd', K[:, :t, t], w[:, :t].clone())
        u[:, t] -= beta_[:, t] * einsum('nt,ntd->nd', K[:, :t, t], u[:, :t].clone())

    # attend to decayed values
    qk = einsum("nsi,nti->nst", q, k)
    mask = q.new_ones(T, T).tril()
    qk = einsum('nst,st->nst', qk, mask)
    y = einsum("nst,ntj->nsj", qk, u)
    return w, u, y

def forward_chunkwise(q, k, v, beta, chunk_size=2):

# @@ -79,18 +70,9 @@ def forward_chunkwise(q, k, v, beta, chunk_size=2):
    )

    # evaluate all chunks in parallel
    w, u, y = decay_values(q_, k_, v_, beta_)

    # stitch chunks sequentially
    q_ = q.view(NH, C, chunk_size, D)
    k_ = k.view(NH, C, chunk_size, D)
    u = u.view(NH, C, chunk_size, D)

# @@ -99,28 +81,27 @@ def stitch_forward(q, k, w, u, C, chunk_size):
    # materialize the state for the leading chunk
    state = einsum('ntv,ntk->nvk', u[:, 0], k_[:, 0])
    y_deltas = [u.new_zeros(NH, chunk_size, D)]
    for c in range(1, C):
        u_old = einsum('nvk,ntk->ntv', state, w[:, c])
        y_cur = einsum('nvk,nsk->nsv', state, q_[:, c])

        # attend to old values
        qk = einsum("nsi,nti->nst", q_[:, c], k_[:, c])
        qk = qk.tril()
        y_prev = einsum("nst,ntj->nsj", qk, u_old)
        y_deltas.append(y_cur - y_prev)
        state_delta = einsum('ntv,ntk->nvk', u[:, c] - u_old, k_[:, c])
        state = state + state_delta

    y_delta = torch.stack(y_deltas, dim=1)
    return y.view(NH, T, D) + y_delta.view(NH, T, D)

def forward_loop(q, k, v, beta):
    "reference: w_t = w_{t-1} + beta_t (v_t - w_t k_t) k_t"
    NH, T, D = shape(q, k, v, beta)

# @@ -142,30 +123,6 @@ def forward_ogloop(q, k, v, beta):
    return stack(y, dim=1)

def shape(q, k, v, beta=None):
    NH, T, D = (q if q is not None else k).shape
    if q is not None:

# @@ -189,114 +146,18 @@ def make_example(NH, T, D):
    beta.requires_grad_()
    return q, k, v, beta

#%%
@no_grad()
def decay_values_backward(d_out_w, d_out_u, d_out_y, q, k, v, beta):
    NH, T, D = shape(q, k, v, beta)

    #
    # allocations
    #
    # this group is loaded from global memory
    q = q.clone()  # load q
    k = k.clone()  # load k
    v = v.clone()  # load v
    beta = beta.clone()  # load beta

# @@ -341,137 +202,82 @@ def decay_values_backward(d_out_w, d_out_u, k, v, beta):
    u_bases = u_bases - v

    #
    # causal_attend_backward for d_q, d_k_2, d_out_u
    #
    tt = einsum('nsv,ntv->nst', d_out_y, u)
    tt = tt.tril()
    d_q = einsum('nst,ntk->nsk', tt, k)
    d_q.clone()  # store
    d_k_2 = einsum('nst,nsk->ntk', tt, q)
    d_k_2.clone()  # store to shared memory?
    tt = einsum('nsk,ntk->nst', q, k)
    tt = tt.tril()
    v.zero_()  # reuse register space of v for d_out_u
    d_out_u = d_out_u.clone()  # load ntw
    d_out_u += einsum('nsv,nst->ntv', d_out_y, tt)

    #
    # backward for d_k, d_v, d_beta
    #
    d_k.zero_()
    for t in range(T-1, -1, -1):
        # d_k
        tt = einsum('njw,ntw->njt', w, d_out_w)  # matmul for the sake of one column t
        tt[:, t:, :] = 0
        tk = einsum('njt,njk->ntk', tt, k)
        tt = einsum('njv,ntv->njt', u, d_out_u)  # matmul for the sake of one column t
        tt[:, t:, :] = 0
        tk += einsum('njt,njk->ntk', tt, k)
        d_k[:, t] += tk[:, t]

        # backpropagate through time, updating only remaining timestamps
        tt.zero_()
        tt[:, t] += bKl[:, t]
        tk = einsum('ntj,ntk->njk', tt, d_out_w)
        d_out_w = d_out_w - tk
        tk = einsum('ntj,ntk->njk', tt, d_out_u)
        d_out_u = d_out_u - tk

    d_k = d_out_w - d_k
    d_k = einsum('ntk,nt->ntk', d_k, beta)

    # decay w and u
    tt = einsum('ntw,njw->ntj', d_out_w, w)
    tt += einsum('ntw,njw->ntj', d_out_u, u)
    tt.tril_(diagonal=-1)
    tk = einsum('ntj,ntk->njk', tt, bk)
    d_k = d_k - tk
    d_k_2 = d_k_2.clone()  # load from shared memory
    d_k = d_k_2 + d_k
    d_k = d_k.clone()  # store

    # d_beta
    w_bases = einsum('ntk,ntk->ntk', w_bases, d_out_w)
    u_bases = einsum('ntw,ntw->ntw', u_bases, d_out_u)

    # d_v using d_out_u register
    d_out_u = einsum('nt,ntv->ntv', beta, d_out_u)
    d_v = d_out_u.clone()  # store

    # continue d_beta reusing the beta register
    beta = einsum('ntk->nt', w_bases)
    beta += einsum('ntv->nt', u_bases)
    d_beta = beta.clone()  # store
    return d_q, d_k, d_v, d_beta

# %%
def stitch_backward(d_y_delta, q, k, w, u, C, chunk_size):
    NH, T, D = shape(q, k, None, None)

# @@ -493,92 +299,58 @@ def stitch_backward(d_y_delta, q, k, w, u, C, chunk_size):
    # materialize the state for the leading chunk
    states[:, 0] = einsum('ntv,ntk->nvk', u[:, 0], k_[:, 0])

    # stitch forward
    for c in range(1, C):
        u_old = einsum('nvk,ntk->ntv', states[:, c-1], w[:, c])
        y_cur = einsum('nvk,nsk->nsv', states[:, c-1], q_[:, c])

        # attend to old values
        qk = einsum("nsi,nti->nst", q_[:, c], k_[:, c])
        qk = qk.tril()
        y_prev = einsum("nst,ntj->nsj", qk, u_old)
        y_delta[:, c] = y_cur - y_prev
        state_delta = einsum('ntv,ntk->nvk', u[:, c] - u_old, k_[:, c])
        states[:, c] = states[:, c-1] + state_delta

    # stitch backward
    for c in range(C-1, -1, -1):
        if c == 0:
            prev_state = torch.zeros_like(d_state)
        else:
            prev_state = states[:, c-1]
        state_decays = einsum('nvk,ntk->ntv', prev_state, w[:, c])
        d_q1 = einsum('nsv,nvk->nsk', d_y_delta[:, c], prev_state)  # prev_output
        d_state1 = einsum('nsv,nsk->nvk', d_y_delta[:, c], q_[:, c])  # prev_output

        # causal_attend_backward for delta
        mask = q.new_ones(T, T).tril()
        d_out_att = -d_y_delta[:, c]
        d_out_state_decays = einsum('nsv,ntv->nst', d_out_att, state_decays)
        d_out_state_decays.tril_()
        d_q2 = einsum('nst,ntk->nsk', d_out_state_decays, k_[:, c])
        d_k1 = einsum('nst,nsk->ntk', d_out_state_decays, q_[:, c])
        qk = einsum('nsk,ntk->nst', q_[:, c], k_[:, c])
        qk.tril_()
        d_state_decays1 = einsum('nsv,nst->ntv', d_out_att, qk)
        d_k2 = einsum('nvk,ntv->ntk', d_state, u[:, c] - state_decays)  # state_add
        d_u[:, c] = einsum('nvk,ntk->ntv', d_state, k_[:, c])  # state_add
        d_state_decays2 = einsum('nvk,ntk->ntv', -d_state, k_[:, c])  # state_add
        d_state_decays = d_state_decays1 + d_state_decays2
        d_w[:, c] = einsum('ntv,nvk->ntk', d_state_decays, prev_state)  # state_decays
        d_state2 = einsum('ntv,ntk->nvk', d_state_decays, w[:, c])  # state_decays
        d_state = d_state + d_state1 + d_state2
        d_q_[:, c] = d_q1 + d_q2
        d_k_[:, c] = d_k1 + d_k2

    return d_q_.view(NH, T, D), d_k_.view(NH, T, D), d_w.view(NH, T, D), d_u.view(NH, T, D)

class DeltaChunkwise(torch.autograd.Function):
    @staticmethod
    def forward(ctx, q, k, v, beta, chunk_size):

# @@ -599,27 +371,24 @@ def backward(ctx, d_y):
        v.view(NH*C, chunk_size, D),
        beta.view(NH*C, chunk_size)
    )
    w, u, y = decay_values(q_, k_, v_, beta_)
    d_q_1, d_k_1, d_w, d_u = stitch_backward(d_y, q, k, w, u, C=C, chunk_size=chunk_size)
    d_w = d_w.view(NH*C, chunk_size, D)
    d_u = d_u.view(NH*C, chunk_size, D)
    d_y = d_y.view(NH*C, chunk_size, D)
    u = u.view(NH*C, chunk_size, D)
    d_q_2, d_k_2, d_v_, d_beta_ = decay_values_backward(d_w, d_u, d_y, q_, k_, v_, beta_)
    d_q_1 = d_q_1.view(NH, T, D)
    d_q_2 = d_q_2.view(NH, C, chunk_size, D)
    d_q_2 = d_q_2.reshape(NH, T, D)
    d_q = d_q_1 + d_q_2
    d_k_2 = d_k_2.reshape(NH, T, D)
    d_k = d_k_1 + d_k_2
    d_v = d_v_.reshape(NH, T, D)
    d_beta = d_beta_.view(NH, T)

# @@ -629,6 +398,8 @@
def test_delta_chunkwise_backward():
    NH, T, D = 2, 16, 2
    q1, k1, v1, beta1 = make_example(NH, T, D)
    y0 = forward_loop(q1, k1, v1, beta1)
    y1 = forward_chunkwise(q1, k1, v1, beta1, chunk_size=2)
    (y1 - torch.ones_like(y1).detach()).pow(2).mean().backward()

# @@ -637,6 +408,7 @@ def test_delta_chunkwise_backward():
    y = DeltaChunkwise.apply(q, k, v, beta, 2)
    (y - torch.ones_like(y).detach()).pow(2).mean().backward()
    assert allclose(y0, y1, atol=1e-5), 'y1 is wrong'
    assert allclose(y1, y, atol=1e-5), 'y is wrong'
    # print(beta1.grad - beta.grad, 'beta.grad diff')
```
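The docstring in this revision describes the single-chunk pattern: compute decayed values with "a little bit of recurrence", then apply causal linear attention to them. A hedged, self-contained distillation of that pattern (this mirrors `decay_values` above but is my own condensed sketch, not the gist's exact code):

```python
import torch
from torch import einsum

def single_chunk_forward(q, k, v, beta):
    # decay values with the DeltaNet forgetting recurrence, then attend causally
    NH, T, D = q.shape
    beta_ = beta.unsqueeze(-1)
    w = beta_ * k.clone()
    u = beta_ * v.clone()
    K = einsum('nsd,ntd->nst', k, k)                # pairwise key overlaps
    for t in range(1, T):                           # the short recurrence
        w[:, t] -= beta_[:, t] * einsum('ns,nsd->nd', K[:, :t, t], w[:, :t].clone())
        u[:, t] -= beta_[:, t] * einsum('ns,nsd->nd', K[:, :t, t], u[:, :t].clone())
    qk = einsum('nsk,ntk->nst', q, k).tril()        # causal linear attention
    return w, u, einsum('nst,ntd->nsd', qk, u)      # y attends to decayed values u
```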
proger revised this gist on Jul 11, 2024: 1 changed file with 32 additions and 30 deletions.
```python
# @@ -301,53 +301,53 @@ def decay_values_backward(d_out_w, d_out_u, k, v, beta):
    v = v.clone()  # load v
    beta = beta.clone()  # load beta
    d_out_w = d_out_w.clone()  # ntk
    w = k.new_zeros(NH, T, D)  # ntk
    u = v.new_zeros(NH, T, D)  # ntw
    w_bases = w.clone()  # ntk
    u_bases = u.clone()  # ntw
    bk = einsum('nt,ntk->ntk', beta, k)
    bKl = k.new_zeros(NH, T, T)
    tt = k.new_zeros(NH, T, T)
    d_k = k.new_zeros(NH, T, D)  # nsk
    tk = k.new_zeros(NH, T, D)  # ntk

    #
    # forward
    #
    tt = einsum('ntk,nsk->nts', k, k)
    tt = tt.tril(diagonal=-1)  # make_causal(0); set_diagonal(0)
    bKl = einsum('nt,nts->nts', beta, tt)  # multiply each row of K by beta
    u_bases = v
    v = einsum('nt,ntw->ntw', beta, v)
    for t in range(T):
        tk = einsum('nts,nsk->ntk', bKl, w)  # matmul for the sake of one row
        w[:, t] = bk[:, t, :] - tk[:, t, :]
        tk = einsum('nts,nsw->ntw', bKl, u)  # matmul for the sake of one row
        u[:, t] = v[:, t, :] - tk[:, t, :]
    w.clone()  # store w
    u.clone()  # store u
    w_bases = einsum('nts,nsk->ntk', tt, w)
    w_bases = k - w_bases
    v = einsum('nts,nsw->ntw', tt, u)
    u_bases = u_bases - v

    #
    # backward for d_k, d_v, d_beta
    #
    v.zero_()  # reuse register space of v for d_out_u
    d_out_u = d_out_u.clone()  # ntw
    d_k.zero_()
    for t in range(T-1, -1, -1):
        # d_k
        tt = einsum('njw,ntw->njt', w, d_out_w)  # matmul for the sake of one column t

# @@ -359,37 +359,39 @@ def decay_values_backward(d_out_w, d_out_u, k, v, beta):
        d_k[:, t] += tk[:, t]

        # backpropagate through time, updating only remaining timestamps
        tt.zero_()
        tt[:, t] += bKl[:, t]
        tk = einsum('ntj,ntk->njk', tt, d_out_w)
        d_out_w -= tk
        tk = einsum('ntj,ntk->njk', tt, d_out_u)
        d_out_u -= tk

    d_k = d_out_w - d_k
    d_k = einsum('nt,ntk->ntk', beta, d_k)

    # decay w and u
    tt.zero_()  # reuse K again
    tt += einsum('njw,ntw->ntj', w, d_out_w)
    tt += einsum('njw,ntw->ntj', u, d_out_u)
    tt.tril_(diagonal=-1)
    tk = einsum('ntj,ntk->njk', tt, bk)
    d_k -= tk
    d_k = d_k.clone()  # store

    # d_beta
    w_bases *= d_out_w
    u_bases *= d_out_u

    # d_v using d_out_u register
    d_out_u = einsum('nt,ntv->ntv', beta, d_out_u)
    d_v = d_out_u.clone()  # store

    # continue d_beta reusing the beta register
    beta = einsum('ntk->nt', w_bases)
    beta += einsum('ntk->nt', u_bases)
    d_beta = beta.clone()  # store
    return d_k, d_v, d_beta
```
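The "matmul for the sake of one row" comments above describe computing a full tiled matmul and keeping only row `t`, a kernel-friendly trade of FLOPs for a fixed-shape operation. A small hedged check of that equivalence (names and shapes are illustrative):

```python
import torch

# a full batched matmul whose row t equals the direct row computation
bKl = torch.randn(2, 8, 8).tril(-1)
w = torch.randn(2, 8, 4)
t = 3
row_via_matmul = torch.einsum('nts,nsk->ntk', bKl, w)[:, t]
row_direct = torch.einsum('ns,nsk->nk', bKl[:, t], w)
assert torch.allclose(row_via_matmul, row_direct, atol=1e-5)
```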
proger revised this gist on Jul 11, 2024: 1 changed file with 44 additions and 54 deletions.
```python
# @@ -296,9 +296,12 @@ def decay_values_backward(d_out_w, d_out_u, k, v, beta):
    # allocations
    #
    # this group is loaded from global memory
    k = k.clone()  # load k
    v = v.clone()  # load v
    beta = beta.clone()  # load beta
    d_out_w = d_out_w.clone()  # ntk
    d_out_u = d_out_u.clone()  # ntw
    w = k.new_zeros(NH, T, D)  # ntk
    u = v.new_zeros(NH, T, D)  # ntw

# @@ -308,93 +311,80 @@ def decay_values_backward(d_out_w, d_out_u, k, v, beta):
    bk = einsum('nt,ntk->ntk', beta, k)
    bv = einsum('nt,ntw->ntw', beta, v)
    K = k.new_zeros(NH, T, T)  # reused later twice in the backward pass
    bKl = k.new_zeros(NH, T, T)
    tt = k.new_zeros(NH, T, T)
    d_k = k.new_zeros(NH, T, D)  # nsk
    tk = k.new_zeros(NH, T, D)  # ntk
    tvec_reg = k.new_zeros(NH, T)  # nt
    kvec_reg = k.new_zeros(NH, D)  # nw

    #
    # forward
    #
    K = einsum('ntd,nsd->nts', k, k)
    K = K.tril(diagonal=-1)  # make_causal(0); set_diagonal(0)
    bKl = einsum('nt,nts->nts', beta, K)  # multiply each row of K by beta
    for t in range(T):
        c_w = einsum('nts,nsk->ntk', bKl, w)  # matmul for the sake of one row
        w[:, t] = bk[:, t, :] - c_w[:, t, :]
        c_u = einsum('nts,nsw->ntw', bKl, u)  # matmul for the sake of one row
        u[:, t] = bv[:, t, :] - c_u[:, t, :]
    w.clone()  # store w
    u.clone()  # store u
    w_bases = einsum('nts,nsk->ntk', K, w)
    w_bases = k - w_bases
    u_bases = einsum('nts,nsw->ntw', K, u)
    u_bases = v - u_bases

    #
    # backward for d_k, d_v, d_beta
    #
    for t in range(T-1, -1, -1):
        # d_k
        tt = einsum('njw,ntw->njt', w, d_out_w)  # matmul for the sake of one column t
        tt[:, t:, :] = 0
        tk = einsum('njt,njk->ntk', tt, k)
        tt = einsum('njv,ntv->njt', u, d_out_u)  # matmul for the sake of one column t
        tt[:, t:, :] = 0
        tk += einsum('njt,njk->ntk', tt, k)
        d_k[:, t] += tk[:, t]

        # backpropagate through time, updating only remaining timestamps
        K.zero_()  # reuse K
        K[:, t] += bKl[:, t]
        tk = einsum('ntj,ntk->njk', K, d_out_w)
        d_out_w -= tk
        tk = einsum('ntj,ntk->njk', K, d_out_u)
        d_out_u -= tk

    d_k = d_out_w - d_k
    d_k = einsum('nt,ntk->ntk', beta, d_k)

    # decay w and u
    K.zero_()  # reuse K again
    K += einsum('njw,ntw->ntj', w, d_out_w)
    K += einsum('njw,ntw->ntj', u, d_out_u)
    K.tril_(diagonal=-1)
    tk = einsum('ntj,ntk->njk', K, bk)
    d_k -= tk
    d_k = d_k.clone()  # store

    # d_beta
    w_bases *= d_out_w
    tvec_reg = einsum('ntk->nt', w_bases)
    u_bases *= d_out_u
    tvec_reg += einsum('ntk->nt', u_bases)
    d_beta = tvec_reg.clone()  # store

    # d_v
    d_out_u = einsum('nt,ntv->ntv', beta, d_out_u)
```
proger revised this gist on Jul 10, 2024: 1 changed file with 8 additions and 5 deletions.
```python
# @@ -308,7 +308,7 @@ def decay_values_backward(d_out_w, d_out_u, k, v, beta):
    bk = einsum('nt,ntk->ntk', beta, k)
    bv = einsum('nt,ntw->ntw', beta, v)
    K = k.new_zeros(NH, T, T)  # reused later as t_regs_uw
    bKl = k.new_zeros(NH, T, T)
    d_k = k.new_zeros(NH, T, D)  # nsk

# @@ -347,6 +347,9 @@ def decay_values_backward(d_out_w, d_out_u, k, v, beta):
    # backward for d_k, d_v, d_beta
    #
    t_regs_uw = K
    t_regs_uw.zero_()
    for t in range(T-1, -1, -1):
        # d_k
        w[:, t, :] = 0

# @@ -370,20 +373,20 @@ def decay_values_backward(d_out_w, d_out_u, k, v, beta):
        w_reg = d_out_w[:, t]
        tk = einsum('nw,ntw->ntw', w_reg, w)
        t_reg = einsum('ntw->nt', tk)
        t_regs_uw[:, t] += t_reg

        # decay u
        w_reg = d_out_u[:, t]
        tk = einsum('nw,ntw->ntw', w_reg, u)
        t_reg = einsum('ntw->nt', tk)
        t_regs_uw[:, t] += t_reg

        # backpropagate through time
        d_out_w += einsum('nj,nk->njk', bKl[:, t, :], d_out_w[:, t])
        d_out_u += einsum('nj,nk->njk', bKl[:, t, :], d_out_u[:, t])

    tk = einsum('nst,nsk->ntk', t_regs_uw, bk)
    d_k -= tk
    d_k = d_k.clone()  # store

    # d_beta
```
proger revised this gist on Jul 10, 2024: 1 changed file with 64 additions and 36 deletions.
```python
# @@ -296,78 +296,106 @@ def decay_values_backward(d_out_w, d_out_u, k, v, beta):
    # allocations
    #
    k = k.clone()  # load k
    v = v.clone()  # load v
    beta = beta.clone()  # load beta
    w = k.new_zeros(NH, T, D)  # ntk
    u = v.new_zeros(NH, T, D)  # ntw
    w_bases = w.clone()  # ntk
    u_bases = u.clone()  # ntw
    bk = einsum('nt,ntk->ntk', beta, k)
    bv = einsum('nt,ntw->ntw', beta, v)
    K = k.new_zeros(NH, T, T)
    bKl = k.new_zeros(NH, T, T)
    d_k = k.new_zeros(NH, T, D)  # nsk
    d_out_w = d_out_w.clone()  # ntk
    d_out_u = d_out_u.clone()  # ntw
    tk = k.new_zeros(NH, T, D)  # ntk
    t_reg = k.new_zeros(NH, T)  # nt
    w_reg = k.new_zeros(NH, D)  # nw

    #
    # forward
    #
    K = einsum('ntd,nsd->nts', k, k)
    K = K.tril(diagonal=-1)  # make_causal(0); set_diagonal(0)
    bKl = einsum('nt,nts->nts', -beta, K)  # multiply each row of K by beta
    for t in range(T):
        c_w = einsum('nts,nsk->ntk', bKl, w)
        w[:, t] = bk[:, t, :] + c_w[:, t, :]
        c_u = einsum('nts,nsw->ntw', bKl, u)
        u[:, t] = bv[:, t, :] + c_u[:, t, :]
    w_bases = einsum('nts,nsk->ntk', K, w)
    w_bases = k - w_bases
    u_bases = einsum('nts,nsw->ntw', K, u)
    u_bases = v - u_bases
    w.clone()  # store w
    u.clone()  # store u

    #
    # backward for d_k, d_v, d_beta
    #
    for t in range(T-1, -1, -1):
        # d_k
        w[:, t, :] = 0
        k[:, t, :] = 0
        u[:, t, :] = 0

        # wk
        t_reg = einsum('nw,njw->nj', d_out_w[:, t], w)
        w_reg = einsum('nj,njk->nk', t_reg, k)
        w_reg = d_out_w[:, t] - w_reg
        w_reg = einsum('n,nk->nk', beta[:, t], w_reg)
        d_k[:, t] += w_reg

        # uk
        t_reg = einsum('nw,njw->nj', d_out_u[:, t], u)
        w_reg = einsum('nj,njk->nk', t_reg, k)
        w_reg = einsum('n,nk->nk', beta[:, t], w_reg)
        d_k[:, t] -= w_reg

        # decay w
        w_reg = d_out_w[:, t]
        tk = einsum('nw,ntw->ntw', w_reg, w)
        t_reg = einsum('ntw->nt', tk)
        w_reg = bk[:, t]
        d_k -= einsum('nt,nk->ntk', t_reg, w_reg)

        # decay u
        w_reg = d_out_u[:, t]
        tk = einsum('nw,ntw->ntw', w_reg, u)
        t_reg = einsum('ntw->nt', tk)
        w_reg = bk[:, t]
        d_k -= einsum('nt,nk->ntk', t_reg, w_reg)

        # backpropagate through time
        d_out_w += einsum('nj,nk->njk', bKl[:, t, :], d_out_w[:, t])
        d_out_u += einsum('nj,nk->njk', bKl[:, t, :], d_out_u[:, t])

    d_k = d_k.clone()  # store

    # d_beta
    w_bases *= d_out_w
    t_reg = einsum('ntk->nt', w_bases)
    u_bases *= d_out_u
    t_reg += einsum('ntk->nt', u_bases)
    d_beta = t_reg.clone()  # store

    # d_v
    d_out_u = einsum('nt,ntv->ntv', beta, d_out_u)
    d_v = d_out_u.clone()  # store
    return d_k, d_v, d_beta

# @@ -386,7 +414,7 @@
def test_equal_decay_values_backward():
    NH, T, D = 1, 16, 16
    q, k, v, beta = make_example(NH, T, D)
    w, u = decay_values(k, v, beta)
```
proger revised this gist on Jul 10, 2024: 1 changed file with 23 additions and 11 deletions.
```python
# @@ -292,9 +292,18 @@ def test_equal_attend_backward2(atol=1e-5):
def decay_values_backward(d_out_w, d_out_u, k, v, beta):
    NH, T, D = shape(None, k, v, beta)

    #
    # allocations
    #
    eye = torch.eye(D, device=k.device, dtype=k.dtype)
    eye = eye.unsqueeze(0).expand(NH, D, D)

    # recompute w and u TK-style
    w = k.new_zeros(NH, T, D)  # ntk
    u = v.new_zeros(NH, T, D)  # ntw
    w_bases = w.clone()
    u_bases = u.clone()
    bk = einsum('nt,ntk->ntk', beta, k)
    bv = einsum('nt,ntw->ntw', beta, v)

# @@ -303,28 +312,31 @@ def decay_values_backward(d_out_w, d_out_u, k, v, beta):
    K = K.tril(diagonal=-1)  # make_causal(0); set_diagonal(0)
    bKl = einsum('nt,nts->nts', -beta, K)  # multiply each row of K by beta
    d_k = k.new_zeros(NH, T, D)  # nsk
    d_beta = beta.new_zeros(NH, T)  # ns
    d_v = v.new_zeros(NH, T, D)  # nsv
    d_out_w_backward = d_out_w  #.clone() # ntk # doesn't seem to need a copy but why?
    d_out_u_backward = d_out_u.clone()  # ntw

    #
    # forward
    #
    for t in range(T):
        c_w = einsum('nts,nsk->ntk', bKl, w)
        w[:, t] = bk[:, t, :] + c_w[:, t, :]
        c_u = einsum('nts,nsw->ntw', bKl, u)
        u[:, t] = bv[:, t, :] + c_u[:, t, :]
    w_bases = k - einsum('nts,nsk->ntk', K, w)
    u_bases = v - einsum('nts,nsw->ntw', K, u)
    w0 = w.clone()  # we will be mutating these, so store original w and u here
    u0 = u.clone()

    #
    # backward for d_k, d_v, d_beta
    #
    for t in range(T-1, -1, -1):
        w[:, t, :] = 0
```
proger revised this gist on Jul 10, 2024: 1 changed file with 7 additions and 7 deletions.
```python
# @@ -346,17 +346,17 @@ def decay_values_backward(d_out_w, d_out_u, k, v, beta):
        d_k[:, :t] -= einsum('nk,ns->nsk', bk[:, t], decay_w)
        d_k[:, :t] -= einsum('nk,ns->nsk', bk[:, t], decay_u)

        # backpropagate through time
        d_out_w_backward[:, :t] += einsum('nj,nk->njk', bKl[:, t, :t], d_out_w_backward[:, t])
        d_out_u_backward[:, :t] += einsum('nj,nk->njk', bKl[:, t, :t], d_out_u_backward[:, t])

    # d_beta
    d_beta += einsum('ntk,ntk->nt', w_bases, d_out_w_backward)
    d_beta += einsum('ntk,ntk->nt', u_bases, d_out_u_backward)

    # d_v
    d_v = einsum('nt,ntv->ntv', beta, d_out_u_backward)
    return d_k, d_v, d_beta
```
proger revised this gist on Jul 10, 2024: 1 changed file with 34 additions and 26 deletions.
```python
# @@ -320,34 +320,42 @@ def decay_values_backward(d_out_w, d_out_u, k, v, beta):
    w_bases = k - einsum('nts,nsk->ntk', K, w)
    u_bases = v - einsum('nts,nsw->ntw', K, u)
    w0 = w.clone()  # we will be mutating these, but the kernel also returns the original w and u
    u0 = u.clone()
    d_out_w_backward = d_out_w.clone()  # ntk
    d_out_u_backward = d_out_u.clone()  # ntw
    for t in range(T-1, -1, -1):
        w[:, t, :] = 0
        k[:, t, :] = 0
        u[:, t, :] = 0
        wk = einsum('njw,njk->nwk', w, k)
        wk = eye - wk
        wk = einsum('n,nwk->nwk', beta[:, t], wk)
        uk = einsum('njw,njk->nwk', u, k)
        uk = einsum('n,nwk->nwk', beta[:, t], uk)

        # d_k
        d_k[:, t] += einsum('nw,nwk->nk', d_out_w_backward[:, t], wk)
        d_k[:, t] -= einsum('nw,nwk->nk', d_out_u_backward[:, t], uk)
        decay_w = einsum('nw,nsw->ns', d_out_w_backward[:, t], w[:, :t])
        decay_u = einsum('nw,nsw->ns', d_out_u_backward[:, t], u[:, :t])
        d_k[:, :t] -= einsum('nk,ns->nsk', bk[:, t], decay_w)
        d_k[:, :t] -= einsum('nk,ns->nsk', bk[:, t], decay_u)

        # d_beta
        d_beta[:, t] += einsum('nk,nk->n', w_bases[:, t], d_out_w_backward[:, t])
        d_beta[:, t] += einsum('nk,nk->n', u_bases[:, t], d_out_u_backward[:, t])

        # d_v
        d_v[:, t] = einsum('n,nv->nv', beta[:, t], d_out_u_backward[:, t])

        # backpropagate through time
        d_out_w_backward[:, :t] += einsum('nj,nk->njk', bKl[:, t, :t], d_out_w_backward[:, t])
        d_out_u_backward[:, :t] += einsum('nj,nk->njk', bKl[:, t, :t], d_out_u_backward[:, t])

    return d_k, d_v, d_beta

# @@ -366,32 +374,32 @@ def backward(ctx, d_out_w, d_out_u):
def test_equal_decay_values_backward():
    NH, T, D = 1, 16, 3
    q, k, v, beta = make_example(NH, T, D)
    w, u = decay_values(k, v, beta)
    (w + u - torch.ones_like(w)).pow(2).mean().backward()
    #(w - torch.ones_like(w)).pow(2).mean().backward()

    q1, k1, v1, beta1 = make_example(NH, T, D)
    w1, u1 = DecayValues.apply(k1, v1, beta1)
    (w1 + u1 - torch.ones_like(w1)).pow(2).mean().backward()
    #(w1 - torch.ones_like(w1)).pow(2).mean().backward()

    # print(v.grad, 'v.grad', v.grad.shape)
    # print(v1.grad, 'v1.grad')
    # print(v.grad - v1.grad, 'v diff')
    assert allclose(v.grad, v1.grad, atol=1e-5), 'v1_grad is wrong'
    # print(beta.grad, 'beta.grad du')
    # print(beta1.grad, 'beta1.grad du')
    assert allclose(beta.grad, beta1.grad, atol=1e-5), 'beta1.grad is wrong'
    # print(k.grad, 'k.grad du')
    # print(k1.grad, 'k1.grad du')
    # print(k.grad - k1.grad, 'diff du')
    assert allclose(k.grad, k1.grad, atol=1e-5), 'k1_grad is wrong'

test_equal_decay_values_backward()

# %%
```
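The test above validates the hand-written backward by pushing the same scalar loss through PyTorch autograd and through the custom `DecayValues.apply`, then comparing input gradients. A hedged, generic version of that pattern for a function returning a single tensor (`grads_match` is my illustration, not part of the gist):

```python
import torch

def grads_match(fn_ref, fn_custom, inputs, atol=1e-5):
    # run both implementations on independent leaf copies of the same inputs
    a = [x.detach().clone().requires_grad_() for x in inputs]
    b = [x.detach().clone().requires_grad_() for x in inputs]
    fn_ref(*a).pow(2).mean().backward()     # scalar loss through autograd
    fn_custom(*b).pow(2).mean().backward()  # same loss through the custom backward
    return all(torch.allclose(x.grad, y.grad, atol=atol) for x, y in zip(a, b))
```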
proger revised this gist on Jul 10, 2024: 1 changed file with 43 additions and 113 deletions.
```python
# @@ -291,133 +291,63 @@ def test_equal_attend_backward2(atol=1e-5):
@no_grad()
def decay_values_backward(d_out_w, d_out_u, k, v, beta):
    NH, T, D = shape(None, k, v, beta)

    # recompute w and u TK-style
    w = k.new_zeros(NH, T, D)  # ntk
    u = v.new_zeros(NH, T, D)  # ntw
    bk = einsum('nt,ntk->ntk', beta, k)
    bv = einsum('nt,ntw->ntw', beta, v)
    K = einsum('ntd,nsd->nts', k, k)  # (T,T) matrix
    K = K.tril(diagonal=-1)  # make_causal(0); set_diagonal(0)
    bKl = einsum('nt,nts->nts', -beta, K)  # multiply each row of K by beta
    for t in range(T):
        c_w = einsum('nts,nsk->ntk', bKl, w)
        w[:, t] = bk[:, t, :] + c_w[:, t, :]
        c_u = einsum('nts,nsw->ntw', bKl, u)
        u[:, t] = bv[:, t, :] + c_u[:, t, :]

    # compute gradients for d_k, d_v, d_beta
    d_k = k.new_zeros(NH, T, D)  # nsk
    d_beta = beta.new_zeros(NH, T)  # ns
    d_v = v.new_zeros(NH, T, D)  # nsv
    eye = torch.eye(D, device=k.device, dtype=k.dtype)
    eye = eye.unsqueeze(0).expand(NH, D, D)
    w_bases = k - einsum('nts,nsk->ntk', K, w)
    u_bases = v - einsum('nts,nsw->ntw', K, u)
    for t in range(T-1, -1, -1):
        wk = einsum('njw,njk->nwk', w[:, :t], k[:, :t])
        wk = eye - wk
        uk = einsum('njw,njk->nwk', u[:, :t], k[:, :t])

        # d_k
        wst = einsum('n,nwk->nwk', beta[:, t], wk)
        d_k[:, t] += einsum('nw,nwk->nk', d_out_w[:, t], wst)
        ust = einsum('n,nwk->nwk', beta[:, t], uk)
        d_k[:, t] -= einsum('nw,nwk->nk', d_out_u[:, t], ust)
        decay_w = einsum('nw,nsw->ns', d_out_w[:, t], w[:, :t])
        decay_u = einsum('nw,nsw->ns', d_out_u[:, t], u[:, :t])
        d_k[:, :t] -= einsum('nk,ns->nsk', bk[:, t], decay_w)
        d_k[:, :t] -= einsum('nk,ns->nsk', bk[:, t], decay_u)

        # d_beta
        d_beta[:, t] += einsum('nk,nk->n', w_bases[:, t], d_out_w[:, t])
        d_beta[:, t] += einsum('nk,nk->n', u_bases[:, t], d_out_u[:, t])

        # d_v
        d_v[:, t] = einsum('n,nv->nv', beta[:, t], d_out_u[:, t])

        # backpropagate through time
        d_out_w += einsum('nj,nk->njk', bKl[:, t, :], d_out_w[:, t])
        d_out_u += einsum('nj,nk->njk', bKl[:, t, :], d_out_u[:, t])

    return d_k, d_v, d_beta

# @@ -436,7 +366,7 @@ def backward(ctx, d_out_w, d_out_u):
def test_equal_decay_values_backward():
    NH, T, D = 1, 4, 1
    q, k, v, beta = make_example(NH, T, D)
    w, u = decay_values(k, v, beta)

# @@ -677,4 +607,4 @@ def test_delta_chunkwise_backward():
    assert allclose(v1.grad, v.grad, atol=1e-5), 'v.grad is wrong'

test_delta_chunkwise_backward()
```
proger revised this gist on Jul 2, 2024: 1 changed file with 32 additions and 32 deletions.
```python
# @@ -120,38 +120,6 @@ def stitch1_forward(state, q, k, w, u):
    return y, state + state_add

def forward_ogloop(q, k, v, beta):
    "reference: w_t = w_{t-1} + beta_t (v_t - w_t k_t) k_t"
    NH, T, D = shape(q, k, v, beta)

# @@ -499,6 +467,38 @@ def test_equal_decay_values_backward():
# %%
def stitch1_backward(d_y, d_state, state, q, k, w, u):
    state_decays = einsum('nvk,ntk->ntv', state, w)
    d_q1 = einsum('nsv,nvk->nsk', d_y, state)  # prev_output
    d_state1 = einsum('nsv,nsk->nvk', d_y, q)  # prev_output
    d_q2, d_k1, d_state_decays1 = causal_attend_backward(-d_y, q, k, state_decays)  # delta
    d_k2 = einsum('nvk,ntv->ntk', d_state, u - state_decays)  # state_add
    d_u = einsum('nvk,ntk->ntv', d_state, k)  # state_add
    d_state_decays2 = einsum('nvk,ntk->ntv', -d_state, k)  # state_add
    d_state_decays = d_state_decays1 + d_state_decays2
    d_w = einsum('ntv,nvk->ntk', d_state_decays, state)  # state_decays
    d_state2 = einsum('ntv,ntk->nvk', d_state_decays, w)  # state_decays
    return d_state + d_state1 + d_state2, d_q1 + d_q2, d_k1 + d_k2, d_w, d_u

class Stitch1(torch.autograd.Function):
    @staticmethod
    def forward(ctx, state, q, k, w, u):
        y, new_state = stitch1_forward(state, q, k, w, u)
        ctx.save_for_backward(state, q, k, w, u)
        return y, new_state

    @staticmethod
    def backward(ctx, d_y, d_state):
        state, q, k, w, u = ctx.saved_tensors
        return stitch1_backward(d_y, d_state, state, q, k, w, u)

def stitch_backward(d_y_delta, q, k, w, u, C, chunk_size):
    NH, T, D = shape(q, k, None, None)
```
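`Stitch1` above returns two outputs, `(y, new_state)`, so its `backward` receives one incoming gradient per forward output (`d_y` and `d_state`). A hedged minimal template of that multi-output `torch.autograd.Function` pattern, with made-up toy math just to show the plumbing:

```python
import torch

class TwoOutputs(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        # two outputs -> backward will get two cotangents
        return x * 2.0, x * 3.0

    @staticmethod
    def backward(ctx, d_a, d_b):
        # chain rule: d/dx of (2x, 3x) contracted with the cotangents
        return 2.0 * d_a + 3.0 * d_b
```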
proger revised this gist on Jul 2, 2024: 1 changed file with 12 additions and 44 deletions.
```python
# @@ -83,11 +83,11 @@ def forward_chunkwise(q, k, v, beta, chunk_size=2):
    y = causal_attend(q_, k_, u)

    # stitch chunks sequentially
    y_delta, _ = stitch_forward(q, k, w, u, C=C, chunk_size=chunk_size)
    return y.view(NH, T, D) + y_delta.view(NH, T, D)

def stitch_forward(q, k, w, u, C, chunk_size):
    "stitch chunks sequentially"
    NH, T, D = shape(q, k, None, None)

# @@ -96,15 +96,17 @@ def forward_stitch(q, k, w, u, C, chunk_size):
    u = u.view(NH, C, chunk_size, D)
    w = w.view(NH, C, chunk_size, D)

    # materialize the state for the leading chunk
    state = einsum('ntv,ntk->nvk', u[:, 0], k_[:, 0])
    deltas = [u.new_zeros(NH, chunk_size, D)]
    for c in range(1, C):
        y_delta1, state = stitch1_forward(state, q_[:, c], k_[:, c], w[:, c], u[:, c])
        deltas.append(y_delta1)
    y_delta = torch.stack(deltas, dim=1)
    return y_delta, state

def stitch1_forward(state, q, k, w, u):

# @@ -496,27 +498,6 @@ def test_equal_decay_values_backward():
# %%
def stitch_backward(d_y_delta, q, k, w, u, C, chunk_size):
    NH, T, D = shape(q, k, None, None)

# @@ -628,23 +609,10 @@ def test_stitch_all(atol=1e-5):
class DeltaChunkwise(torch.autograd.Function):
    @staticmethod
    def forward(ctx, q, k, v, beta, chunk_size):
        y = forward_chunkwise(q, k, v, beta, chunk_size)
        ctx.save_for_backward(q, k, v, beta)
        ctx.chunk_size = chunk_size
        return y

    @staticmethod
    def backward(ctx, d_y):

# @@ -685,7 +653,7 @@ def backward(ctx, d_y):
    return d_q, d_k, d_v, d_beta, None

def test_delta_chunkwise_backward():
    NH, T, D = 2, 16, 2
    q1, k1, v1, beta1 = make_example(NH, T, D)

# @@ -709,4 +677,4 @@ def test_delta_simple_backward():
    assert allclose(v1.grad, v.grad, atol=1e-5), 'v.grad is wrong'

test_delta_chunkwise_backward()
```
proger revised this gist on Jul 2, 2024: 1 changed file with 4 additions and 22 deletions.
```python
# @@ -318,24 +318,6 @@ def test_equal_attend_backward2(atol=1e-5):
#%%
@no_grad()
def decay_values_backward(d_out_w, d_out_u, k, v, beta):
    NH, T, D = shape(None, k, v, beta)

# @@ -473,7 +455,7 @@ def decay_values_backward(d_out_w, d_out_u, k, v, beta):
class DecayValues(torch.autograd.Function):
    @staticmethod
    def forward(ctx, k, v, beta):
        w, u = decay_values(k, v, beta)
        ctx.save_for_backward(k, v, beta)
        return w, u

# @@ -487,7 +469,7 @@
    NH, T, D = 2, 8, 3
    q, k, v, beta = make_example(NH, T, D)
    w, u = decay_values(k, v, beta)
    #(w + u - torch.ones_like(w)).pow(2).mean().backward()
    (w - torch.ones_like(w)).pow(2).mean().backward()

# @@ -591,7 +573,7 @@ def test_stitch_all(atol=1e-5):
    NH, T, D = 1, 4, 2
    C, chunk_size = 2, 2
    q, k, v, beta = make_example(NH, T, D)
    w, u = decay_values(k, v, beta)
    q.requires_grad_()
    k.requires_grad_()

# @@ -612,7 +594,7 @@ def test_stitch_all(atol=1e-5):
    # print(u.grad, 'u.grad')
    q1, k1, v1, beta1 = make_example(NH, T, D)
    w1, u1 = decay_values(k1, v1, beta1)
    q1.requires_grad_()
    k1.requires_grad_()
```
proger revised this gist on Jul 2, 2024: 1 changed file with 66 additions and 26 deletions.
```python
# @@ -324,14 +324,15 @@ def decay_values_forward(k, v, beta):
    w = k.new_zeros(NH, T, D)
    u = v.new_zeros(NH, T, D)
    beta_ = beta.unsqueeze(-1)
    K = einsum('nsd,ntd->nst', k, k)  # (T,T) matrix
    w[:, 0] = beta_[:, 0] * k[:, 0]
    u[:, 0] = beta_[:, 0] * v[:, 0]
    for t in range(1, T):
        w[:, t] = beta_[:, t] * k.clone()[:, t] - beta_[:, t] * einsum('ns,nsd->nd', K[:, :t, t], w[:, :t].clone())
        u[:, t] = beta_[:, t] * v.clone()[:, t] - beta_[:, t] * einsum('ns,nsd->nd', K[:, :t, t], u[:, :t].clone())
    return w, u

# @@ -340,6 +341,7 @@ def decay_values_backward(d_out_w, d_out_u, k, v, beta):
    NH, T, D = shape(None, k, v, beta)
    DV = D
    K = einsum('nsd,ntd->nst', k, k)  # (T,T) matrix
    beta_ = beta.unsqueeze(-1)
    w = k.new_zeros(NH, T, D)
    u = v.new_zeros(NH, T, D)

# @@ -353,11 +355,11 @@
    u_t = b_t v_t - b_t \sum_{s=0}^{t-1} k_s^T k_t u_s
    """
    w[:, 0] = beta_[:, 0] * k[:, 0]
    u[:, 0] = beta_[:, 0] * v[:, 0]
    for t in range(1, T):
        w[:, t] = beta_[:, t] * k.clone()[:, t] - beta_[:, t] * einsum('nsj,nj,nsd->nd', k[:, :t], k[:, t], w[:, :t].clone())
        u[:, t] = beta_[:, t] * v.clone()[:, t] - beta_[:, t] * einsum('nsj,nj,nsd->nd', k[:, :t], k[:, t], u[:, :t].clone())
    """
    b0 k0
    b1 k1
    b2 k2

# @@ -402,14 +404,14 @@
    for l in range(t):
        c = (k[:, t] * k[:, l]).sum(-1)
        if s < l:
            WK[:, t, s] -= einsum('n,n,nij->nij', beta[:, t], c, WK[:, l, s])
        if s == l:
            WK[:, t, s] -= einsum('n,nj,ni->nij', beta[:, t], k[:, t], w[:, s])
            WK[:, t, s] -= einsum('n,n,nij->nij', beta[:, t], c, WK[:, l, l])
    # [s=t]
    WK[:, t, t, arange(D), arange(D)] = beta_[:, t]
    WK[:, t, t] += einsum('n,nsj,nsi->nij', -beta[:, t], k[:, :t], w[:, :t])
    """
    u_t = b_t v_t - b_t \sum_{s=0}^{t-1} k_s^T k_t u_s

# @@ -425,7 +427,7 @@
    UV[:, t, :t] = einsum('n,nt,ntsv->nsv', -beta[:, t], K[:, :t, t], UV[:, :t, :t])
    # [s=t]
    UV[:, t, t] = beta_[:, t]
    """
    d u_t / d k_s =

# @@ -444,20 +446,20 @@
    for l in range(t):
        c = (k[:, t] * k[:, l]).sum(-1)
        if s < l:
            UK[:, t, s] -= einsum('n,n,nij->nij', beta[:, t], c, UK[:, l, s])
        if s == l:
            UK[:, t, s] -= einsum('n,nk,nv->nkv', beta[:, t], k[:, t], u[:, s])
            UK[:, t, s] -= einsum('n,n,nij->nij', beta[:, t], c, UK[:, l, l])
    # [s=t]
    UK[:, t, t] -= einsum('n,nsk,nsv->nkv', beta[:, t], k[:, :t], u[:, :t])

    UB = u.new_zeros(NH, T, S, D)  # d u_t / d beta_s
    for t in range(T):
        for s in range(t):
            for l in range(t):
                c = (k[:, t] * k[:, l]).sum(-1)
                UB[:, t, s] -= einsum('n,n,nj->nj', beta[:, t], c, UB[:, l, s])
        UB[:, t, t] = v[:, t] - einsum('ns,nsd->nd', K[:, :t, t], u[:, :t])

# @@ -482,7 +484,7 @@ def backward(ctx, d_out_w, d_out_u):
def test_equal_decay_values_backward():
    NH, T, D = 2, 8, 3
    q, k, v, beta = make_example(NH, T, D)
    w, u = decay_values_forward(k, v, beta)

# @@ -669,22 +671,60 @@ def backward(ctx, d_y):
    chunk_size = ctx.chunk_size
    C = T // chunk_size
    q_, k_, v_, beta_ = (
        q.view(NH*C, chunk_size, D),
        k.view(NH*C, chunk_size, D),
        v.view(NH*C, chunk_size, D),
        beta.view(NH*C, chunk_size)
    )
    w, u = decay_values(k_, v_, beta_)
    d_q_1, d_k_1, d_w, d_u1 = stitch_backward(d_y, q, k, w, u, C=C, chunk_size=chunk_size)
    d_w = d_w.view(NH*C, chunk_size, D)
    d_u1 = d_u1.view(NH*C, chunk_size, D)
    d_y = d_y.view(NH*C, chunk_size, D)
    u = u.view(NH*C, chunk_size, D)
    d_q_2, d_k_2, d_u2 = causal_attend_backward(d_y, q_, k_, u)
    d_u = d_u1 + d_u2
    d_k_3, d_v_, d_beta_ = decay_values_backward(d_w, d_u, k_, v_, beta_)
    d_q_1 = d_q_1.view(NH, T, D)
    d_q_2 = d_q_2.view(NH, C, chunk_size, D)
    d_q_2 = d_q_2.reshape(NH, T, D)
    d_q = d_q_1 + d_q_2
    d_k_2 = d_k_2.reshape(NH, T, D)
    d_k_3 = d_k_3.reshape(NH, T, D)
    d_k = d_k_1 + d_k_2 + d_k_3
    d_v = d_v_.reshape(NH, T, D)
    d_beta = d_beta_.view(NH, T)
    return d_q, d_k, d_v, d_beta, None

def test_delta_simple_backward():
    NH, T, D = 2, 16, 2
    q1, k1, v1, beta1 = make_example(NH, T, D)
    y1 = forward_chunkwise(q1, k1, v1, beta1, chunk_size=2)
    (y1 - torch.ones_like(y1).detach()).pow(2).mean().backward()

    q, k, v, beta = make_example(NH, T, D)
    y = DeltaChunkwise.apply(q, k, v, beta, 2)
    (y - torch.ones_like(y).detach()).pow(2).mean().backward()
    assert allclose(y1, y, atol=1e-5), 'y is wrong'
    # print(beta1.grad - beta.grad, 'beta.grad diff')
    # print(q1.grad - q.grad, 'q.grad diff')
    # print(k1.grad - k.grad, 'k.grad diff')
    # print(v1.grad - v.grad, 'v.grad diff')
    assert allclose(q1.grad, q.grad, atol=1e-5), 'q.grad is wrong'
    assert allclose(beta1.grad, beta.grad, atol=1e-5), 'beta.grad is wrong'
    assert allclose(k1.grad, k.grad, atol=1e-5), 'k.grad is wrong'
    assert allclose(v1.grad, v.grad, atol=1e-5), 'v.grad is wrong'

test_delta_simple_backward()
```
proger revised this gist
Jul 2, 2024. 1 changed file with 16 additions and 7 deletions.
@@ -385,7 +385,16 @@ def decay_values_backward(d_out_w, d_out_u, k, v, beta):
        WB[:, t, :t] = einsum('n,nt,ntsk->nsk', -beta[:, t], K[:, :t, t], WB[:, :t, :t])
        WB[:, t, t] = k[:, t] - einsum('nt,ntk->nk', K[:, :t, t], w[:, :t])

    """
    w_0 = b_0 k_0
    w_1 = b_1 k_1 - b_1 k_0^T k_1 w_0
    w_2 = b_2 k_2 - b_2 k_0^T k_2 w_0 - b_2 k_1^T k_2 w_1
    w_t = b_t k_t - b_t \sum_{s=0}^{t-1} k_s^T k_t w_s
    u_t = b_t v_t - b_t \sum_{s=0}^{t-1} k_s^T k_t u_s
    """
    WK = w.new_zeros(NH, T, S, D, D) # d w_t / d k_s # ntsij
    for t in range(T):
        # [s<t]
@@ -395,11 +404,12 @@ def decay_values_backward(d_out_w, d_out_u, k, v, beta):
                if s < l:
                    WK[:, t, s] -= beta[:, t] * c * WK[:, l, s]
                if s == l:
                    WK[:, t, s] -= beta[:, t] * einsum('nj,ni->nij', k[:, t], w[:, s])
                    WK[:, t, s] -= beta[:, t] * c * WK[:, l, l]
        # [s=t]
        WK[:, t, t, arange(D), arange(D)] = beta[:, t]
        WK[:, t, t] += -beta[:, t] * einsum('nsj,nsi->nij', k[:, :t], w[:, :t])
    """
    u_t = b_t v_t - b_t \sum_{s=0}^{t-1} k_s^T k_t u_s
@@ -452,8 +462,7 @@
    UB[:, t, t] = v[:, t] - einsum('ns,nsd->nd', K[:, :t, t], u[:, :t])

    d_beta = einsum('ntk,ntsk->ns', d_out_w, WB) + einsum('ntv,ntsv->ns', d_out_u, UB) # sum T and D out
    d_k = einsum('nti,ntsij->nsj', d_out_w, WK) + einsum('ntv,ntskv->nsk', d_out_u, UK) # sum T out
    d_v = einsum('ntv,ntsv->nsv', d_out_u, UV) # sum T out
    return d_k, d_v, d_beta
@@ -473,7 +482,7 @@ def backward(ctx, d_out_w, d_out_u):
def test_equal_decay_values_backward():
    NH, T, D = 1, 8, 3
    q, k, v, beta = make_example(NH, T, D)
    w, u = decay_values_forward(k, v, beta)
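The WK[:, t, t, arange(D), arange(D)] = beta_[:, t] line above writes beta_t onto the diagonal of each (D, D) Jacobian block, replacing the previous revision's plain assignment. The indexing pattern in isolation, as a standalone sketch added for illustration:

import torch
from torch import arange

# Advanced indexing with two arange(D) index tensors selects the diagonal
# of the trailing (D, D) block, so a (N, 1) value broadcasts across it.
N, D = 2, 3
block = torch.zeros(N, D, D)
beta_t = torch.rand(N, 1)                  # mirrors beta_[:, t]
block[:, arange(D), arange(D)] = beta_t    # block[n] = beta_t[n] * I
assert torch.allclose(block, beta_t.unsqueeze(-1) * torch.eye(D))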
proger revised this gist
Jul 1, 2024. 1 changed file with 101 additions and 201 deletions.
@@ -336,8 +336,9 @@ def decay_values_forward(k, v, beta):
    return w, u

@no_grad()
def decay_values_backward(d_out_w, d_out_u, k, v, beta):
    NH, T, D = shape(None, k, v, beta)
    DV = D
    K = einsum('nsd,ntd->nst', k, k) # (T,T) matrix
    w = k.new_zeros(NH, T, D)
@@ -381,24 +382,24 @@ def decay_values_backward(k, v, beta):
    WB = w.new_zeros(NH, T, S, D) # d w_t / d beta_s
    for t in range(T):
        WB[:, t, :t] = einsum('n,nt,ntsk->nsk', -beta[:, t], K[:, :t, t], WB[:, :t, :t])
        WB[:, t, t] = k[:, t] - einsum('nt,ntk->nk', K[:, :t, t], w[:, :t])

    WK = w.new_zeros(NH, T, S, D) # d w_t / d k_s
    for t in range(T):
        # [s<t]
        for s in range(t):
            for l in range(t):
                c = (k[:, t] * k[:, l]).sum(-1)
                if s < l:
                    WK[:, t, s] -= beta[:, t] * c * WK[:, l, s]
                if s == l:
                    WK[:, t, s] -= beta[:, t] * einsum('nk,nw->nkw', k[:, t], w[:, s])
                    WK[:, t, s] -= beta[:, t] * c * WK[:, l, l]
        # [s=t]
        WK[:, t, t] = beta[:, t]
        WK[:, t, t] += -beta[:, t] * einsum('nsk,nsw->nkw', k[:, :t], w[:, :t])
    """
    u_t = b_t v_t - b_t \sum_{s=0}^{t-1} k_s^T k_t u_s
@@ -407,11 +408,11 @@ def decay_values_backward(k, v, beta):
    + [s<t] - b_t \sum_{j=0}^{t-1} k_j^T k_t (d u_j / d v_s)
    """
    UV = u.new_zeros(NH, T, S, DV) # d u_t / d v_s
    for t in range(T):
        # [s<t]
        UV[:, t, :t] = einsum('n,nt,ntsv->nsv', -beta[:, t], K[:, :t, t], UV[:, :t, :t])
        # [s=t]
        UV[:, t, t] = beta[:, t]
@@ -425,100 +426,8 @@ def decay_values_backward(k, v, beta):
    [s<t][s=l]: product rule for u_l(k_s) and k_l=k_s
    """
    UK = u.new_zeros(NH, T, S, D, DV) # d u_t / d k_s
    for t in range(T):
        # [s<t]
        for s in range(t):
@@ -542,116 +451,55 @@ def decay_values_backward_only_u(d_out, k, v, beta):
        UB[:, t, t] = v[:, t] - einsum('ns,nsd->nd', K[:, :t, t], u[:, :t])

    d_beta = einsum('ntk,ntsk->ns', d_out_w, WB) + einsum('ntv,ntsv->ns', d_out_u, UB) # sum T and D out
    ##### what is wrong with WK???
    d_k = einsum('ntk,ntskw->nsk', d_out_w, WK) + einsum('ntv,ntskv->nsk', d_out_u, UK) # sum T out
    d_v = einsum('ntv,ntsv->nsv', d_out_u, UV) # sum T out
    return d_k, d_v, d_beta

class DecayValues(torch.autograd.Function):
    @staticmethod
    def forward(ctx, k, v, beta):
        w, u = decay_values_forward(k, v, beta)
        ctx.save_for_backward(k, v, beta)
        return w, u

    @staticmethod
    def backward(ctx, d_out_w, d_out_u):
        k, v, beta = ctx.saved_tensors
        return decay_values_backward(d_out_w, d_out_u, k, v, beta)

def test_equal_decay_values_backward():
    NH, T, D = 1, 4, 2
    q, k, v, beta = make_example(NH, T, D)
    w, u = decay_values_forward(k, v, beta)
    #(w + u - torch.ones_like(w)).pow(2).mean().backward()
    (w - torch.ones_like(w)).pow(2).mean().backward()

    q1, k1, v1, beta1 = make_example(NH, T, D)
    w1, u1 = DecayValues.apply(k1, v1, beta1)
    #(w1 + u1 - torch.ones_like(w1)).pow(2).mean().backward()
    (w1 - torch.ones_like(w1)).pow(2).mean().backward()

    # print(v.grad, 'v.grad', v.grad.shape)
    # print(v1.grad, 'v1.grad')
    # print(v.grad - v1.grad, 'v diff')
    # assert allclose(v.grad, v1.grad, atol=1e-5), 'v1_grad is wrong'

    print(k.grad, 'k.grad du')
    print(k1.grad, 'k1.grad du')
    print(k.grad - k1.grad, 'diff du')
    assert allclose(k.grad, k1.grad, atol=1e-5), 'k1_grad is wrong'

    # print(beta.grad, 'beta.grad du')
    # print(beta1.grad, 'beta1.grad du')
    assert allclose(beta.grad, beta1.grad, atol=1e-5), 'beta1.grad is wrong'

test_equal_decay_values_backward()

# %%
@@ -778,4 +626,56 @@ def test_stitch_all(atol=1e-5):
    assert allclose(w.grad, w1.grad, atol=atol), 'w.grad is wrong'

test_stitch_all()

#%%

class DeltaChunkwise(torch.autograd.Function):
    @staticmethod
    def forward(ctx, q, k, v, beta, chunk_size):
        NH, T, D = shape(q, k, v, beta)
        C = T // chunk_size
        q_, k_, v_, beta_ = (
            q.view(NH*C, chunk_size, D),
            k.view(NH*C, chunk_size, D),
            v.view(NH*C, chunk_size, D),
            beta.view(NH*C, chunk_size)
        )
        # evaluate all chunks in parallel
        w, u = decay_values(k_, v_, beta_)
        y = causal_attend(q_, k_, u)
        # stitch chunks sequentially
        y_delta = forward_stitch(q, k, w, u, C=C, chunk_size=chunk_size)
        ctx.save_for_backward(q, k, v, beta)
        ctx.chunk_size = chunk_size
        return y.view(NH, T, D) + y_delta.view(NH, T, D)

    @staticmethod
    def backward(ctx, d_y):
        q, k, v, beta = ctx.saved_tensors
        NH, T, D = shape(q, k, v, beta)
        chunk_size = ctx.chunk_size
        C = T // chunk_size
        d_y = d_y.view(NH, C, chunk_size, D)
        q_, k_, v_, beta_ = (
            q.view(NH*C, chunk_size, D),
            k.view(NH*C, chunk_size, D),
            v.view(NH*C, chunk_size, D),
            beta.view(NH*C, chunk_size)
        )
        w, u = decay_values(k_, v_, beta_)
        d_q_1, d_k_1, d_w, d_u1 = stitch_backward(d_y, q_, k_, w, u, C=C, chunk_size=chunk_size)
        d_q_2, d_k_2, d_u2 = causal_attend_backward(d_y, q_, k_, u)
        d_u = d_u1 + d_u2
        d_k_3, d_v_, d_beta_ = decay_values_backward(d_w, d_u, k_, v_, beta_)
        d_q = (d_q_1 + d_q_2).view(NH, T, D)
        d_k = (d_k_1 + d_k_2 + d_k_3).view(NH, T, D)
        d_v = d_v_.view(NH, T, D)
        d_beta = d_beta_.view(NH, T)
        return d_q, d_k, d_v, d_beta, None
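In matrix form, the forward pass that DeltaChunkwise assembles in this revision reads as follows, with S_{c-1} the (D_v, D_k) state carried between chunks (a restatement of forward_stitch plus causal_attend, writing tril for the causal mask):

\begin{aligned}
Y_c &= \mathrm{tril}(Q_c K_c^\top)\,\big(U_c - W_c S_{c-1}^\top\big) + Q_c S_{c-1}^\top \\
S_c &= S_{c-1} + \big(U_c - W_c S_{c-1}^\top\big)^\top K_c
\end{aligned}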
proger revised this gist
Jul 1, 2024. 1 changed file with 218 additions and 11 deletions.
@@ -96,28 +96,58 @@ def forward_stitch(q, k, w, u, C, chunk_size):
    u = u.view(NH, C, chunk_size, D)
    w = w.view(NH, C, chunk_size, D)
    y_delta = u.new_zeros(NH, C, chunk_size, D)
    # materialize the state for the leading chunk
    state = einsum('ntv,ntk->nvk', u[:, 0], k_[:, 0])
    for c in range(1, C):
        y_delta[:, c], state = stitch1_forward(state, q_[:, c], k_[:, c], w[:, c], u[:, c])
    return y_delta

def stitch1_forward(state, q, k, w, u):
    state_decays = einsum('nvk,ntk->ntv', state, w)
    state_add = einsum('ntv,ntk->nvk', u - state_decays, k)
    delta = causal_attend(q, k, state_decays)
    prev_output = einsum('nvk,nsk->nsv', state, q)
    y = prev_output - delta
    return y, state + state_add

def stitch1_backward(d_y, d_state, state, q, k, w, u):
    state_decays = einsum('nvk,ntk->ntv', state, w)
    d_q1 = einsum('nsv,nvk->nsk', d_y, state) # prev_output
    d_state1 = einsum('nsv,nsk->nvk', d_y, q) # prev_output
    d_q2, d_k1, d_state_decays1 = causal_attend_backward(-d_y, q, k, state_decays) # delta
    d_k2 = einsum('nvk,ntv->ntk', d_state, u - state_decays) # state_add
    d_u = einsum('nvk,ntk->ntv', d_state, k) # state_add
    d_state_decays2 = einsum('nvk,ntk->ntv', -d_state, k) # state_add
    d_state_decays = d_state_decays1 + d_state_decays2
    d_w = einsum('ntv,nvk->ntk', d_state_decays, state) # state_decays
    d_state2 = einsum('ntv,ntk->nvk', d_state_decays, w) # state_decays
    return d_state + d_state1 + d_state2, d_q1 + d_q2, d_k1 + d_k2, d_w, d_u

class Stitch1(torch.autograd.Function):
    @staticmethod
    def forward(ctx, state, q, k, w, u):
        y, new_state = stitch1_forward(state, q, k, w, u)
        ctx.save_for_backward(state, q, k, w, u)
        return y, new_state

    @staticmethod
    def backward(ctx, d_y, d_state):
        state, q, k, w, u = ctx.saved_tensors
        return stitch1_backward(d_y, d_state, state, q, k, w, u)

def forward_ogloop(q, k, v, beta):
@@ -571,4 +601,181 @@ def test_delta_simple_backward():
    assert allclose(v1.grad, v.grad, atol=1e-5), 'v.grad is wrong'

test_delta_simple_backward()

# %%

def test_stitch1(atol=1e-5):
    NH, T, D = 1, 4, 2
    q, k, v, beta = make_example(NH, T, D)
    w, u = decay_values_forward(k, v, beta)
    state = torch.randn_like(einsum('ntv,ntk->nvk', v, k))
    q.requires_grad_()
    k.requires_grad_()
    v.requires_grad_()
    beta.requires_grad_()
    w.retain_grad()
    u.retain_grad()
    state.requires_grad_()
    y, new_state = stitch1_forward(state, q, k, w, u)
    loss = (y - torch.ones_like(y)).pow(2).mean() + (new_state - torch.ones_like(new_state)).pow(2).mean()
    loss.backward()

    q1, k1, v1, beta1 = make_example(NH, T, D)
    w1, u1 = decay_values_forward(k1, v1, beta1)
    state1 = torch.randn_like(einsum('ntv,ntk->nvk', v1, k1))
    q1.requires_grad_()
    k1.requires_grad_()
    v1.requires_grad_()
    beta1.requires_grad_()
    w1.retain_grad()
    u1.retain_grad()
    state1.requires_grad_()
    y1, new_state1 = Stitch1.apply(state1, q1, k1, w1, u1)
    loss = (y1 - torch.ones_like(y1)).pow(2).mean() + (new_state1 - torch.ones_like(new_state1)).pow(2).mean()
    loss.backward()

    assert allclose(y, y1, atol=atol), 'y is wrong'
    assert allclose(new_state, new_state1, atol=atol), 'new_state is wrong'
    assert allclose(u.grad, u1.grad, atol=atol), 'u.grad is wrong'
    assert allclose(v.grad, v1.grad, atol=atol), 'v.grad is wrong'
    assert allclose(k.grad, k1.grad, atol=atol), 'k.grad is wrong'
    assert allclose(q.grad, q1.grad, atol=atol), 'q.grad is wrong'
    assert allclose(beta.grad, beta1.grad, atol=atol), 'beta.grad is wrong'
    assert allclose(w.grad, w1.grad, atol=atol), 'w.grad is wrong'
    assert allclose(state.grad, state1.grad, atol=atol), 'state.grad is wrong'

test_stitch1()

# %%

def stitch_forward(q, k, w, u, C, chunk_size):
    "stitch chunks sequentially"
    NH, T, D = shape(q, k, None, None)
    q_ = q.view(NH, C, chunk_size, D)
    k_ = k.view(NH, C, chunk_size, D)
    u = u.view(NH, C, chunk_size, D)
    w = w.view(NH, C, chunk_size, D)
    # materialize the state for the leading chunk
    state = einsum('ntv,ntk->nvk', u[:, 0], k_[:, 0])
    deltas = [u.new_zeros(NH, chunk_size, D)]
    for c in range(1, C):
        y_delta1, state = stitch1_forward(state, q_[:, c], k_[:, c], w[:, c], u[:, c])
        deltas.append(y_delta1)
    y_delta = torch.stack(deltas, dim=1)
    return y_delta, state

def stitch_backward(d_y_delta, q, k, w, u, C, chunk_size):
    NH, T, D = shape(q, k, None, None)
    d_y_delta = d_y_delta.view(NH, C, chunk_size, D)
    q_ = q.view(NH, C, chunk_size, D)
    k_ = k.view(NH, C, chunk_size, D)
    u = u.view(NH, C, chunk_size, D)
    w = w.view(NH, C, chunk_size, D)
    d_q_ = q.new_zeros(NH, C, chunk_size, D)
    d_k_ = k.new_zeros(NH, C, chunk_size, D)
    d_w = w.new_zeros(NH, C, chunk_size, D)
    d_u = u.new_zeros(NH, C, chunk_size, D)
    d_state = w.new_zeros(NH, D, D) # NHVK
    y_delta = u.new_zeros(NH, C, chunk_size, D) # leading chunk has zero delta
    # storing all states for BPTT
    states = k.new_zeros(NH, C, D, D) # NHCVK
    # materialize the state for the leading chunk
    states[:, 0] = einsum('ntv,ntk->nvk', u[:, 0], k_[:, 0])
    for c in range(1, C):
        y_delta[:, c], states[:, c] = stitch1_forward(states[:, c-1], q_[:, c], k_[:, c], w[:, c], u[:, c])
    for c in range(C-1, 0, -1):
        (
            d_state, d_q_[:, c], d_k_[:, c], d_w[:, c], d_u[:, c]
        ) = stitch1_backward(d_y_delta[:, c], d_state, states[:, c-1], q_[:, c], k_[:, c], w[:, c], u[:, c])
    (
        d_state, d_q_[:, 0], d_k_[:, 0], d_w[:, 0], d_u[:, 0]
    ) = stitch1_backward(d_y_delta[:, 0], d_state, torch.zeros_like(d_state), q_[:, 0], k_[:, 0], w[:, 0], u[:, 0])
    return d_q_.view(NH, T, D), d_k_.view(NH, T, D), d_w.view(NH, T, D), d_u.view(NH, T, D)

class Stitch(torch.autograd.Function):
    @staticmethod
    def forward(ctx, q, k, w, u, C, chunk_size):
        y_delta, state = stitch_forward(q, k, w, u, C, chunk_size)
        ctx.save_for_backward(q, k, w, u)
        ctx.C = C
        ctx.chunk_size = chunk_size
        return y_delta

    @staticmethod
    def backward(ctx, d_y_delta):
        q, k, w, u = ctx.saved_tensors
        return *stitch_backward(d_y_delta, q, k, w, u, ctx.C, ctx.chunk_size), None, None

def test_stitch_all(atol=1e-5):
    NH, T, D = 1, 4, 2
    C, chunk_size = 2, 2
    q, k, v, beta = make_example(NH, T, D)
    w, u = decay_values_forward(k, v, beta)
    q.requires_grad_()
    k.requires_grad_()
    v.requires_grad_()
    beta.requires_grad_()
    w.retain_grad()
    u.retain_grad()
    y, new_state = stitch_forward(q, k, w, u, C=C, chunk_size=chunk_size)
    loss = (y - torch.ones_like(y)).pow(2).mean()
    loss.backward()
    # print(q.grad, 'q.grad')
    # print(k.grad, 'k.grad')
    # print(v.grad, 'v.grad')
    # print(beta.grad, 'beta.grad')
    # print(w.grad, 'w.grad')
    # print(u.grad, 'u.grad')

    q1, k1, v1, beta1 = make_example(NH, T, D)
    w1, u1 = decay_values_forward(k1, v1, beta1)
    q1.requires_grad_()
    k1.requires_grad_()
    v1.requires_grad_()
    beta1.requires_grad_()
    w1.retain_grad()
    u1.retain_grad()
    y1 = Stitch.apply(q1, k1, w1, u1, C, chunk_size)
    loss = (y1 - torch.ones_like(y1)).pow(2).mean()
    loss.backward()

    assert allclose(y, y1, atol=atol), 'y is wrong'
    assert allclose(u.grad, u1.grad, atol=atol), 'u.grad is wrong'
    assert allclose(v.grad, v1.grad, atol=atol), 'v.grad is wrong'
    # print(k.grad, 'k.grad')
    # print(k1.grad, 'k1.grad')
    assert allclose(k.grad, k1.grad, atol=atol), 'k.grad is wrong'
    assert allclose(q.grad, q1.grad, atol=atol), 'q.grad is wrong'
    assert allclose(beta.grad, beta1.grad, atol=atol), 'beta.grad is wrong'
    assert allclose(w.grad, w1.grad, atol=atol), 'w.grad is wrong'

test_stitch_all()
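stitch_backward above is plain backpropagation through time: rerun the forward to materialize every carried state, then fold gradients from the last chunk to the first while threading d_state. The same pattern in a stripped-down sketch (the names here are illustrative, not part of the gist):

import torch

def reverse_scan(d_ys, states, init_state, step_backward):
    # states[c] is the state entering step c; step_backward mirrors
    # stitch1_backward: given this step's output gradient, the incoming
    # state gradient, and the entering state, it returns the state
    # gradient to pass on plus this step's input gradients
    d_state = torch.zeros_like(init_state)
    d_xs = []
    for c in reversed(range(len(d_ys))):
        d_state, d_x = step_backward(d_ys[c], d_state, states[c])
        d_xs.append(d_x)
    return d_state, d_xs[::-1]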
proger revised this gist
Jul 1, 2024. 1 changed file with 2 additions and 4 deletions.
@@ -103,9 +103,7 @@ def forward_stitch(q, k, w, u, C, chunk_size):
    mask = q.new_ones(chunk_size, chunk_size).tril()
    for c in range(1, C):
        y_delta[:, c], state = stitch1(state, q_[:, c], k_[:, c], w[:, c], u[:, c], mask)
    return y_delta
@@ -119,7 +117,7 @@
    y_delta1 = prev_output - delta
    state_add = einsum('nsv,nsk->nvk', u - state_decays, k)
    return y_delta1, state + state_add

def forward_ogloop(q, k, v, beta):
proger revised this gist
Jul 1, 2024. 1 changed file with 15 additions and 5 deletions.
@@ -100,16 +100,26 @@ def forward_stitch(q, k, w, u, C, chunk_size):
    # materialize the state for the leading chunk
    state = einsum('ntv,ntk->nvk', u[:, 0], k_[:, 0])
    mask = q.new_ones(chunk_size, chunk_size).tril()
    for c in range(1, C):
        y_delta1, state_add = stitch1(state, q_[:, c], k_[:, c], w[:, c], u[:, c], mask)
        y_delta[:, c] = y_delta1
        state.add_(state_add)
    return y_delta

def stitch1(state, q, k, w, u, mask):
    prev_output = einsum('nvk,nsk->nsv', state, q)
    state_decays = einsum('nvk,nsk->nsv', state, w)
    # delta = causal_attend(q, k, state_decays)
    delta = einsum('nsk,ntk,st,ntv->nsv', q, k, mask, state_decays)
    y_delta1 = prev_output - delta
    state_add = einsum('nsv,nsk->nvk', u - state_decays, k)
    return y_delta1, state_add

def forward_ogloop(q, k, v, beta):
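The fused einsum in stitch1 folds the tril mask into the contraction instead of calling causal_attend; the two forms are interchangeable, as this quick standalone check (added for illustration) shows:

import torch

# The 4-operand einsum applies the mask inside the contraction, which
# matches masking the score matrix first and then attending.
torch.manual_seed(0)
N, T, D = 2, 4, 3
q, k, x = torch.randn(N, T, D), torch.randn(N, T, D), torch.randn(N, T, D)
mask = torch.ones(T, T).tril()
fused = torch.einsum('nsk,ntk,st,ntv->nsv', q, k, mask, x)
twostep = torch.einsum('nst,ntv->nsv', torch.einsum('nsk,ntk->nst', q, k).tril(), x)
assert torch.allclose(fused, twostep, atol=1e-6)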
proger revised this gist
Jul 1, 2024. 1 changed file with 25 additions and 15 deletions.
@@ -82,25 +82,34 @@ def forward_chunkwise(q, k, v, beta, chunk_size=2):
    w, u = decay_values(k_, v_, beta_)
    y = causal_attend(q_, k_, u)
    # stitch chunks sequentially
    y_delta = forward_stitch(q, k, w, u, C=C, chunk_size=chunk_size)
    return y.view(NH, T, D) + y_delta.view(NH, T, D)

def forward_stitch(q, k, w, u, C, chunk_size):
    "stitch chunks sequentially"
    NH, T, D = shape(q, k, None, None)
    q_ = q.view(NH, C, chunk_size, D)
    k_ = k.view(NH, C, chunk_size, D)
    u = u.view(NH, C, chunk_size, D)
    w = w.view(NH, C, chunk_size, D)
    y_delta = q.new_zeros(NH, C, chunk_size, D)
    # materialize the state for the leading chunk
    state = einsum('ntv,ntk->nvk', u[:, 0], k_[:, 0])
    for c in range(1, C):
        prev_output = einsum('nvk,ntk->ntv', state, q_[:, c])
        state_decays = einsum('nvk,ntk->ntv', state, w[:, c])
        y_delta[:, c] = prev_output - causal_attend(q_[:, c], k_[:, c], state_decays)
        state.add_(einsum('ntv,ntk->nvk', u[:, c] - state_decays, k_[:, c]))
    return y_delta

def forward_ogloop(q, k, v, beta):
@@ -153,7 +162,8 @@ def shape(q, k, v, beta=None):
    NH, T, D = (q if q is not None else k).shape
    if q is not None: assert q.shape == (NH, T, D)
    if v is not None: assert k.shape == v.shape
    if beta is not None: assert beta.shape == (NH, T)
    return NH, T, D
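For orientation, this chunk-level loop reproduces the token-level delta rule of the original recurrence (per the references in the module docstring; stated here with the state laid out as (D_v, D_k), matching the 'nvk' einsums above):

S_t = S_{t-1} + \beta_t \,(v_t - S_{t-1} k_t)\, k_t^\top, \qquad y_t = S_t\, q_t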
proger revised this gist
Jul 1, 2024. 1 changed file with 39 additions and 26 deletions.
@@ -451,10 +451,11 @@ def test_equal_decay_values_backward():

@no_grad()
def decay_values_backward_only_u(d_out, k, v, beta):
    "trimmed version of decay_values_backward that does not use w"
    NH, T, D = shape(None, k, v, beta)
    DV = D
    K = einsum('nsd,ntd->nst', k, k) # (T,T) matrix
    u = v.new_zeros(NH, T, D)
@@ -469,7 +470,7 @@
        UV[:, t, :t] = einsum('n,nt,ntsd->nsd', -beta[:, t], K[:, :t, t], UV[:, :t, :t])
        UV[:, t, t] = beta[:, t]

    UK = u.new_zeros(NH, T, S, D, DV) # d u_t / d k_s
    for t in range(T):
        # [s<t]
        for s in range(t):
@@ -478,10 +479,10 @@
                if s < l:
                    UK[:, t, s] -= beta[:, t] * c * UK[:, l, s]
                if s == l:
                    UK[:, t, s] -= beta[:, t] * einsum('nk,nv->nkv', k[:, t], u[:, s])
                    UK[:, t, s] -= beta[:, t] * c * UK[:, l, l]
        # [s=t]
        UK[:, t, t] -= beta[:, t] * einsum('nsk,nsv->nkv', k[:, :t], u[:, :t])

    UB = u.new_zeros(NH, T, S, D) # d u_t / d beta_s
@@ -493,51 +494,63 @@
        UB[:, t, t] = v[:, t] - einsum('ns,nsd->nd', K[:, :t, t], u[:, :t])

    d_beta = einsum('ntv,ntsv->ns', d_out, UB)
    d_k = einsum('ntv,ntskv->nsk', d_out, UK)
    d_v = einsum('ntv,ntsv->nsv', d_out, UV)
    return u, d_k, d_v, d_beta

def forward_simple1(q, k, v, beta):
    NH, T, D = shape(q, k, v, beta)
    K = einsum('nsd,ntd->nst', k, k) # (T,T) matrix
    u = einsum('nt,ntd->ntd', beta, v.clone())
    for t in range(1,T):
        u[:, t] -= einsum('n,nt,ntd->nd', beta[:,t], K[:, :t, t], u[:, :t].clone())
    return u, causal_attend(q, k, u)

class DeltaSimple(torch.autograd.Function):
    @staticmethod
    def forward(ctx, q, k, v, beta):
        u, y = forward_simple1(q, k, v, beta)
        ctx.save_for_backward(q, k, v, beta, u)
        return y

    @staticmethod
    def backward(ctx, d_out):
        q, k, v, beta, u = ctx.saved_tensors
        d_q, d_k1, d_u = causal_attend_backward(d_out, q, k, u)
        u, d_k, d_v, d_beta = decay_values_backward_only_u(d_u, k, v, beta)
        d_k = d_k1 + d_k
        return d_q, d_k, d_v, d_beta

def test_delta_simple_backward():
    NH, T, D = 1, 3, 2
    q1, k1, v1, beta1 = make_example(NH, T, D)
    u1, y1 = forward_simple1(q1, k1, v1, beta1)
    (y1 - torch.ones_like(y1).detach()).pow(2).mean().backward()

    q, k, v, beta = make_example(NH, T, D)
    y = DeltaSimple.apply(q, k, v, beta)
    (y - torch.ones_like(y).detach()).pow(2).mean().backward()

    assert allclose(y1, y, atol=1e-5), 'y is wrong'
    print(beta1.grad - beta.grad, 'beta.grad diff')
    #print(q1.grad - q.grad, 'q.grad diff')
    print(k1.grad - k.grad, 'k.grad diff')
    print(v1.grad - v.grad, 'v.grad diff')
    assert allclose(q1.grad, q.grad, atol=1e-5), 'q.grad is wrong'
    assert allclose(beta1.grad, beta.grad, atol=1e-5), 'beta.grad is wrong'
    assert allclose(k1.grad, k.grad, atol=1e-5), 'k.grad is wrong'
    assert allclose(v1.grad, v.grad, atol=1e-5), 'v.grad is wrong'

test_delta_simple_backward()
proger revised this gist
Jul 1, 2024. 1 changed file with 95 additions and 0 deletions.
@@ -446,3 +446,98 @@ def test_equal_decay_values_backward():
    assert allclose(beta.grad, beta_grad, atol=1e-5), 'beta_grad wrt u is wrong'

test_equal_decay_values_backward()

#%%

@no_grad()
def decay_values_backward_only_u(k, v, beta):
    "trimmed version of decay_values_backward that does not use w"
    NH, T, D = shape(None, k, v, beta)
    K = einsum('nsd,ntd->nst', k, k) # (T,T) matrix
    u = v.new_zeros(NH, T, D)
    u[:, 0] = beta[:, 0] * v[:, 0]
    for t in range(1,T):
        u[:, t] = beta[:, t] * v.clone()[:, t] - beta[:, t] * einsum('nsj,nj,nsd->nd', k[:, :t], k[:, t], u[:, :t].clone())

    S = T
    UV = u.new_zeros(NH, T, S, D) # d u_t / d v_s
    for t in range(T):
        UV[:, t, :t] = einsum('n,nt,ntsd->nsd', -beta[:, t], K[:, :t, t], UV[:, :t, :t])
        UV[:, t, t] = beta[:, t]

    UK = u.new_zeros(NH, T, S, D) # d u_t / d k_s
    for t in range(T):
        # [s<t]
        for s in range(t):
            for l in range(t):
                c = (k[:, t] * k[:, l]).sum(-1)
                if s < l:
                    UK[:, t, s] -= beta[:, t] * c * UK[:, l, s]
                if s == l:
                    UK[:, t, s] -= beta[:, t] * einsum('nd,nh->nd', k[:, t], u[:, s])
                    UK[:, t, s] -= beta[:, t] * c * UK[:, l, l]
        # [s=t]
        UK[:, t, t] -= beta[:, t] * einsum('nsd,nsh->nd', k[:, :t], u[:, :t])

    UB = u.new_zeros(NH, T, S, D) # d u_t / d beta_s
    for t in range(T):
        for s in range(t):
            for l in range(t):
                c = (k[:, t] * k[:, l]).sum(-1)
                UB[:, t, s] -= beta[:, t] * c * UB[:, l, s]
        UB[:, t, t] = v[:, t] - einsum('ns,nsd->nd', K[:, :t, t], u[:, :t])

    d_beta = einsum('ntsd->ns', UB) # sum T and D out
    d_k = einsum('ntsd->nsd', UK) # sum T out
    d_v = einsum('ntsd->nsd', UV) # sum T out
    return u, d_k, d_v, d_beta

def simple_backward(d_out, q, k, v, beta):
    u, du_dk, du_dv, du_dbeta = decay_values_backward_only_u(k, v, beta)
    d_q, d_k1, d_u = causal_attend_backward(d_out, q, k, u)
    d_k = d_k1 + d_u * du_dk
    d_v = d_out * d_u * du_dv
    d_beta = (d_out * d_u * du_dbeta.unsqueeze(-1)).sum(-1)
    return d_q, d_k, d_v, d_beta

class DeltaSimple(torch.autograd.Function):
    @staticmethod
    def forward(ctx, q, k, v, beta):
        ctx.save_for_backward(q, k, v, beta)
        w, u = decay_values(k, v, beta)
        return causal_attend(q, k, u)

    @staticmethod
    def backward(ctx, d_out):
        q, k, v, beta = ctx.saved_tensors
        return simple_backward(d_out, q, k, v, beta)

def test_delta_simple_backward():
    NH, T, D = 1, 16, 2
    q1, k1, v1, beta1 = make_example(NH, T, D)
    y1 = forward_simple(q1, k1, v1, beta1)
    (y1 - torch.ones_like(y1).detach()).pow(2).mean().backward()

    q, k, v, beta = make_example(NH, T, D)
    y = DeltaSimple.apply(q, k, v, beta)
    (y - torch.ones_like(y).detach()).pow(2).mean().backward()

    # print(beta1.grad - beta.grad, 'beta.grad diff')
    # print(k1.grad - k.grad, 'k.grad diff')
    # print(v1.grad - v.grad, 'v.grad diff')
    assert allclose(y1, y, atol=1e-5), 'y is wrong'

test_delta_simple_backward()
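The hand-derived UV/UK/UB blocks above can be spot-checked against torch.autograd.functional.jacobian on a purely functional copy of the u recurrence. A hedged sketch (u_ref is a local helper written for this check, not a function from the gist):

import torch
from torch.autograd.functional import jacobian

def u_ref(k, v, beta):
    # functional restatement of the recurrence:
    # u_t = b_t v_t - b_t sum_{s<t} (k_s^T k_t) u_s
    us = []
    for t in range(v.shape[1]):
        acc = beta[:, t, None] * v[:, t]
        for s in range(t):
            c = (k[:, s] * k[:, t]).sum(-1, keepdim=True)
            acc = acc - beta[:, t, None] * c * us[s]
        us.append(acc)
    return torch.stack(us, dim=1)

torch.manual_seed(0)
NH, T, D = 1, 4, 2
k, v = torch.randn(NH, T, D), torch.randn(NH, T, D)
beta = torch.rand(NH, T)
UV_ref = jacobian(lambda v_: u_ref(k, v_, beta), v)  # (NH, T, D, NH, T, D)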
proger revised this gist
Jul 1, 2024. 1 changed file with 50 additions and 24 deletions.
@@ -159,14 +159,23 @@ def shape(q, k, v, beta=None):
    return NH, T, D

def make_example(NH, T, D):
    manual_seed(0)
    q = randn(NH, T, D) / D**0.5
    q.requires_grad_()
    k = randn(NH, T, D) / D**0.5
    k.requires_grad_()
    v = randn(NH, T, D) / D**0.5
    v.requires_grad_()
    beta = randn(NH, T).sigmoid()
    beta.requires_grad_()
    return q, k, v, beta

def test_equal(atol=1e-6):
    NH, T, D = 2*3, 128, 16
    #NH, T, D = 1, 8, 3
    q, k, v, beta = make_example(NH, T, D)
    y1 = forward_ogloop(q, k, v, beta)
    y2 = forward_scanloop(q, k, v, beta)
@@ -202,16 +211,21 @@
    return d_q, q_k, d_v

class CausalAttend(torch.autograd.Function):
    @staticmethod
    def forward(ctx, q, k, v):
        ctx.save_for_backward(q, k, v)
        return causal_attend(q, k, v)

    @staticmethod
    def backward(ctx, d_out):
        q, k, v = ctx.saved_tensors
        return causal_attend_backward(d_out, q, k, v)

def test_equal_attend_backward(atol=1e-5):
    NH, T, D = 1*1, 512, 64
    q, k, v, beta = make_example(NH, T, D)
    y = causal_attend(q, k, v)
    d_q, d_k, d_v = causal_attend_backward(torch.ones_like(y), q, k, v)
@@ -227,7 +241,32 @@
    # assert torch.allclose(g_hook.grad, d_g, atol=1e-1), 'g.grad is wrong'

def test_equal_attend_backward2(atol=1e-5):
    NH, T, D = 1, 2, 2
    q1, k1, v1, beta1 = make_example(NH, T, D)
    y1 = causal_attend(q1, k1, v1)
    (y1 - torch.ones_like(y1)).pow(2).mean().backward()

    q, k, v, beta = make_example(NH, T, D)
    y = CausalAttend.apply(q, k, v)
    (y - torch.ones_like(y)).pow(2).mean().backward()

    # print(q1.grad - q.grad, 'q.grad diff')
    # print(k1.grad - k.grad, 'k.grad diff')
    # print(v1.grad - v.grad, 'v.grad diff')
    # print(k.grad, 'k.grad')
    # print(k1.grad, 'k1.grad')
    # print(v.grad, 'v.grad')
    # print(v1.grad, 'v1.grad')
    assert (q1.grad - q.grad).abs().max() < atol, 'q.grad is wrong'
    assert (k1.grad - k.grad).abs().max() < atol, 'k.grad is wrong'
    assert (v1.grad - v.grad).abs().max() < atol, 'v.grad is wrong'

test_equal_attend_backward()
test_equal_attend_backward2()

#%%
proger revised this gist
Jul 1, 2024. 1 changed file with 31 additions and 24 deletions.
@@ -34,6 +34,7 @@ import os
os.environ['TORCH_LOGS'] = 'output_code'

import torch
from torch import einsum, randn, allclose, stack, eye, manual_seed, no_grad, set_float32_matmul_precision, compile, arange

set_float32_matmul_precision('high')
@@ -49,8 +50,8 @@ def decay_values(k, v, beta):
    K = einsum('nsd,ntd->nst', k, k) # (T,T) matrix
    for t in range(1,T):
        w[:, t] -= beta_[:, t] * einsum('nt,ntd->nd', K[:, :t, t], w[:, :t].clone())
        u[:, t] -= beta_[:, t] * einsum('nt,ntd->nd', K[:, :t, t], u[:, :t].clone())
    return w, u
@@ -185,19 +186,19 @@ def test_equal(atol=1e-6):

@no_grad()
def attend_backward(d_out, q, k, v, g):
    d_q = einsum('nsv,ntk,ntv,nst->nsk', d_out, k, v, g)
    d_k = einsum('nsv,nsk,ntv,nst->ntk', d_out, q, v, g)
    d_v = einsum('nsv,nsk,ntk,nst->ntv', d_out, q, k, g)
    d_g = einsum('nsv,nsk,ntk,ntv,nst->nst', d_out, q, k, v, g)
    return d_q, d_k, d_v, d_g

@no_grad()
def causal_attend_backward(d_out, q, k, v, diagonal=0):
    NH, T, D = shape(q, k, v)
    mask = q.new_ones(T, T).tril(diagonal=diagonal).unsqueeze(0)
    d_q, q_k, d_v, _d_mask = attend_backward(d_out, q, k, v, mask)
    return d_q, q_k, d_v
@@ -213,7 +214,7 @@ def test_equal_attend_backward(atol=1e-5):
    v.requires_grad_()
    y = causal_attend(q, k, v)
    d_q, d_k, d_v = causal_attend_backward(torch.ones_like(y), q, k, v)
    y.sum().backward()

    assert allclose(q.grad, d_q, atol=atol), 'q.grad is wrong'
@@ -353,7 +354,7 @@ def decay_values_backward(k, v, beta):
        # [s=t]
        UK[:, t, t] -= beta[:, t] * einsum('nsd,nsh->nd', k[:, :t], u[:, :t])

    UB = u.new_zeros(NH, T, S, D) # d u_t / d beta_s
    for t in range(T):
        for s in range(t):
@@ -368,36 +369,42 @@
    d_v = einsum('ntsd->nsd', UV) # sum T out
    return w, u, d_k, d_v, d_beta

def make_example(NH, T, D):
    manual_seed(0)
    q = randn(NH, T, D) / D**0.5
    q.requires_grad_()
    k = randn(NH, T, D) / D**0.5
    k.requires_grad_()
    v = randn(NH, T, D) / D**0.5
    v.requires_grad_()
    beta = randn(NH, T).sigmoid()
    beta.requires_grad_()
    return q, k, v, beta

def test_equal_decay_values_backward():
    NH, T, D = 1, 4, 2
    q, k, v, beta = make_example(NH, T, D)

    ## to try these modify summations in d_beta and d_k above
    # k, v, beta = make()
    # w, u = decay_values_forward(k, v, beta)
    # w.sum().backward()
    # w, u, k_grad, v_grad, beta_grad = decay_values_backward(k, v, beta)
    # # print(k.grad, 'k.grad')
    # # print(k_grad, 'k_grad')
    # # print(k.grad - k_grad, 'diff')
    # assert allclose(k.grad, k_grad, atol=1e-5), 'k_grad is wrong'
    # assert allclose(beta.grad, beta_grad, atol=1e-5), 'beta_grad is wrong'

    q, k, v, beta = make_example(NH, T, D)
    w, u = decay_values_forward(k, v, beta)
    (w + u).sum().backward()
    w, u, k_grad, v_grad, beta_grad = decay_values_backward(k, v, beta)
    # print(v.grad, 'v.grad', v.grad.shape)
    # print(v_grad, 'v_grad')
    # print(v.grad - v_grad, 'v diff')
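The einsums in attend_backward are the standard gradients of masked linear attention. With Y = (Q K^T \odot M) V and upstream gradient dY, they read (a restatement of the einsums above, omitting the d_g term):

\frac{\partial \mathcal{L}}{\partial Q} = \big(dY\,V^\top \odot M\big)\,K, \qquad
\frac{\partial \mathcal{L}}{\partial K} = \big(dY\,V^\top \odot M\big)^\top Q, \qquad
\frac{\partial \mathcal{L}}{\partial V} = \big(Q K^\top \odot M\big)^\top dY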
proger revised this gist
Jun 30, 2024. 1 changed file with 67 additions and 20 deletions.
@@ -289,16 +289,15 @@ def decay_values_backward(k, v, beta):
    """
    S = T

    WB = w.new_zeros(NH, T, S, D) # d w_t / d beta_s
    for t in range(T):
        WB[:, t, :t] = einsum('n,nt,ntsd->nsd', -beta[:, t], K[:, :t, t], WB[:, :t, :t])
        WB[:, t, t] = k[:, t] - einsum('nt,ntd->nd', K[:, :t, t], w[:, :t])

    WK = w.new_zeros(NH, T, S, D) # d w_t / d k_s
    for t in range(T):
        # [s<t]
        for s in range(t):
@@ -320,15 +319,53 @@
    + [s<t] - b_t \sum_{j=0}^{t-1} k_j^T k_t (d u_j / d v_s)
    """
    UV = u.new_zeros(NH, T, S, D) # d u_t / d v_s
    for t in range(T):
        # [s<t]
        UV[:, t, :t] = einsum('n,nt,ntsd->nsd', -beta[:, t], K[:, :t, t], UV[:, :t, :t])
        # [s=t]
        UV[:, t, t] = beta[:, t]

    """
    d u_t / d k_s =
    - b_t \sum_{l=0}^{t-1} (D_{k_s} k_l^T k_t u_l)

    D_{k_s} k_l^T k_t u_l = [s=t] k_l u_l^T # sum out dimensions of u_j after the outer product?
    D_{k_s} k_l^T k_t u_l =
    [s<t][s<l]: only u_l is a function of k_s
    [s<t][s=l]: product rule for u_l(k_s) and k_l=k_s
    """
    UK = u.new_zeros(NH, T, S, D) # d u_t / d k_s
    for t in range(T):
        # [s<t]
        for s in range(t):
            for l in range(t):
                c = (k[:, t] * k[:, l]).sum(-1)
                if s < l:
                    UK[:, t, s] -= beta[:, t] * c * UK[:, l, s]
                if s == l:
                    UK[:, t, s] -= beta[:, t] * einsum('nd,nh->nd', k[:, t], u[:, s])
                    UK[:, t, s] -= beta[:, t] * c * UK[:, l, l]
        # [s=t]
        UK[:, t, t] -= beta[:, t] * einsum('nsd,nsh->nd', k[:, :t], u[:, :t])

    UB = w.new_zeros(NH, T, S, D) # d u_t / d beta_s
    for t in range(T):
        for s in range(t):
            for l in range(t):
                c = (k[:, t] * k[:, l]).sum(-1)
                UB[:, t, s] -= beta[:, t] * c * UB[:, l, s]
        UB[:, t, t] = v[:, t] - einsum('ns,nsd->nd', K[:, :t, t], u[:, :t])

    d_beta = einsum('ntsd->ns', WB) + einsum('ntsd->ns', UB) # sum T and D out
    d_k = einsum('ntsd->nsd', WK) + einsum('ntsd->nsd', UK) # sum T out
    d_v = einsum('ntsd->nsd', UV) # sum T out
    return d_k, d_v, d_beta
@@ -346,23 +383,33 @@ def make():
    beta.requires_grad_()
    return k, v, beta

    ## to try these modify summations in d_beta and d_k above
    # k, v, beta = make()
    # w, u = decay_values_forward(k, v, beta)
    # w.sum().backward()
    # k_grad, v_grad, beta_grad = decay_values_backward(k, v, beta)
    # # print(k.grad, 'k.grad')
    # # print(k_grad, 'k_grad')
    # # print(k.grad - k_grad, 'diff')
    # assert allclose(k.grad, k_grad, atol=1e-5), 'k_grad is wrong'
    # assert allclose(beta.grad, beta_grad, atol=1e-5), 'beta_grad is wrong'

    k, v, beta = make()
    w, u = decay_values_forward(k, v, beta)
    (w + u).sum().backward()
    k_grad, v_grad, beta_grad = decay_values_backward(k, v, beta)
    # print(v.grad, 'v.grad', v.grad.shape)
    # print(v_grad, 'v_grad')
    # print(v.grad - v_grad, 'v diff')
    assert allclose(v.grad, v_grad, atol=1e-5), 'v_grad is wrong'
    # print(k.grad, 'k.grad du')
    # print(k_grad, 'k_grad du')
    # print(k.grad - k_grad, 'diff du')
    assert allclose(k.grad, k_grad, atol=1e-5), 'k_grad wrt u is wrong'
    # print(beta.grad, 'beta.grad du')
    # print(beta_grad, 'beta_grad du')
    assert allclose(beta.grad, beta_grad, atol=1e-5), 'beta_grad wrt u is wrong'

test_equal_decay_values_backward()
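A T = 2 hand check of the UK branches above (added for illustration, writing each Jacobian block with the k index first, as the later 'nkv' layout does): with u_0 = \beta_0 v_0 and u_1 = \beta_1 v_1 - \beta_1 (k_0^\top k_1)\, u_0,

\frac{\partial u_1}{\partial k_0} = -\beta_1\, k_1 u_0^\top \;\; (\text{the } [s{=}l] \text{ branch}), \qquad
\frac{\partial u_1}{\partial k_1} = -\beta_1\, k_0 u_0^\top \;\; (\text{the } [s{=}t] \text{ term})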
proger revised this gist
Jun 30, 2024. 1 changed file with 25 additions and 9 deletions.
@@ -261,6 +261,8 @@ def decay_values_backward(k, v, beta):
    w_1 = b_1 k_1 - b_1 k_0^T k_1 w_0
    w_2 = b_2 k_2 - b_2 k_0^T k_2 w_0 - b_2 k_1^T k_2 w_1
    w_t = b_t k_t - b_t \sum_{s=0}^{t-1} k_s^T k_t w_s
    u_t = b_t v_t - b_t \sum_{s=0}^{t-1} k_s^T k_t u_s
    """
    w[:, 0] = beta[:, 0] * k[:, 0]
    u[:, 0] = beta[:, 0] * v[:, 0]
@@ -311,7 +313,21 @@
        WK[:, t, t] = beta[:, t]
        WK[:, t, t] += -beta[:, t] * einsum('nsj,nsi->nj', k[:, :t], w[:, :t])

    """
    u_t = b_t v_t - b_t \sum_{s=0}^{t-1} k_s^T k_t u_s

    d u_t / d v_s =
    [s=t] b_t
    + [s<t] - b_t \sum_{j=0}^{t-1} k_j^T k_t (d u_j / d v_s)
    """
    for t in range(T):
        # [s<t]
        UV[:, t, :t] = einsum('n,nt,ntsd->nsd', -beta[:, t], K[:, :t, t], UV[:, :t, :t])
        # [s=t]
        UV[:, t, t] = beta[:, t]

    d_beta = einsum('ntsd->ns', WB) # sum T and D out
    d_k = einsum('ntsd->nsd', WK) # sum T out
    d_v = einsum('ntsd->nsd', UV) # sum T out
@@ -340,13 +356,13 @@ def make():
    assert allclose(k.grad, k_grad, atol=1e-5), 'k_grad is wrong'
    assert allclose(beta.grad, beta_grad, atol=1e-5), 'beta_grad is wrong'

    k, v, beta = make()
    w, u = decay_values_forward(k, v, beta)
    u.sum().backward()
    k_grad, v_grad, beta_grad = decay_values_backward(k, v, beta)
    print(v.grad, 'v.grad', v.grad.shape)
    print(v_grad, 'v_grad')
    print(v.grad - v_grad, 'v diff')
    assert allclose(v.grad, v_grad, atol=1e-5), 'v_grad is wrong'

test_equal_decay_values_backward()
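Because the recurrence is linear in v, the Jacobian blocks d u_t / d v_s satisfy the same recurrence, which can be summarized as one triangular solve (an observation added here, in the spirit of the WY representation cited in the module docstring): stacking u and v over t, with each entry acting as a scalar multiple of I_D,

(I + L)\, u = \mathrm{diag}(\beta)\, v, \qquad
L_{ts} = \beta_t\, (k_s^\top k_t)\, [s < t], \qquad
\frac{\partial u}{\partial v} = (I + L)^{-1}\, \mathrm{diag}(\beta)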
proger revised this gist
Jun 30, 2024. 1 changed file with 6 additions and 6 deletions.
@@ -34,7 +34,7 @@ import os
os.environ['TORCH_LOGS'] = 'output_code'

from torch import einsum, randn, allclose, stack, eye, manual_seed, no_grad, set_float32_matmul_precision, compile, arange

set_float32_matmul_precision('high')
@@ -300,12 +300,12 @@ def decay_values_backward(k, v, beta):
    for t in range(T):
        # [s<t]
        for s in range(t):
            # for j in range(t): [s<j]
            WK[:, t, s] -= beta[:, t] * einsum('nj,njd->nd', K[:, s+1:t, t], WK[:, s+1:t, s])
        # for s in range(t): for j in range(t): [s<t][s=j]
        WK[:, t, :t] -= beta[:, t] * einsum('nh,nsd->nsh', k[:, t], w[:, :t])
        WK[:, t, :t] -= beta[:, t] * einsum('ns,nsd->nsd', K[:, :t, t], WK[:, arange(t), arange(t)])
        # [s=t]
        WK[:, t, t] = beta[:, t]