Skip to content

Instantly share code, notes, and snippets.

@yzh119
yzh119 / dgl-transformer.py
Created December 3, 2020 09:03
Efficient Sparse Transformer implementation with DGL's builtin operators
import dgl
import dgl.ops as ops
import numpy as np
import torch as th
import torch.nn as nn
class FFN(nn.Module):
def __init__(self, d_feat, d_ffn, dropout=0.1):
super().__init__()
self.linear_0 = nn.Linear(d_feat, d_ffn)