Skip to content

Instantly share code, notes, and snippets.

@mufeili
Last active August 22, 2019 03:42
Show Gist options
  • Select an option

  • Save mufeili/4b103b809a8bd1241daf6fae6c1631ac to your computer and use it in GitHub Desktop.

Select an option

Save mufeili/4b103b809a8bd1241daf6fae6c1631ac to your computer and use it in GitHub Desktop.

Revisions

  1. mufeili revised this gist Aug 22, 2019. 1 changed file with 157 additions and 155 deletions.
    312 changes: 157 additions & 155 deletions alchemy.py
    Original file line number Diff line number Diff line change
    @@ -26,6 +26,161 @@ def __init__(self, graph=None, label=None):
    self.graph = graph
    self.label = label

    def alchemy_nodes(mol):
    """Featurization for all atoms in a molecule. The atom indices
    will be preserved.
    Args:
    mol : rdkit.Chem.rdchem.Mol
    RDKit molecule object
    Returns
    atom_feats_dict : dict
    Dictionary for atom features
    """
    atom_feats_dict = defaultdict(list)
    is_donor = defaultdict(int)
    is_acceptor = defaultdict(int)

    fdef_name = osp.join(RDConfig.RDDataDir, 'BaseFeatures.fdef')
    mol_featurizer = ChemicalFeatures.BuildFeatureFactory(fdef_name)
    mol_feats = mol_featurizer.GetFeaturesForMol(mol)
    mol_conformers = mol.GetConformers()
    assert len(mol_conformers) == 1
    geom = mol_conformers[0].GetPositions()

    for i in range(len(mol_feats)):
    if mol_feats[i].GetFamily() == 'Donor':
    node_list = mol_feats[i].GetAtomIds()
    for u in node_list:
    is_donor[u] = 1
    elif mol_feats[i].GetFamily() == 'Acceptor':
    node_list = mol_feats[i].GetAtomIds()
    for u in node_list:
    is_acceptor[u] = 1

    num_atoms = mol.GetNumAtoms()
    for u in range(num_atoms):
    atom = mol.GetAtomWithIdx(u)
    symbol = atom.GetSymbol()
    atom_type = atom.GetAtomicNum()
    aromatic = atom.GetIsAromatic()
    hybridization = atom.GetHybridization()
    num_h = atom.GetTotalNumHs()
    atom_feats_dict['pos'].append(torch.FloatTensor(geom[u]))
    atom_feats_dict['node_type'].append(atom_type)

    h_u = []
    h_u += [
    int(symbol == x) for x in ['H', 'C', 'N', 'O', 'F', 'S', 'Cl']
    ]
    h_u.append(atom_type)
    h_u.append(is_acceptor[u])
    h_u.append(is_donor[u])
    h_u.append(int(aromatic))
    h_u += [
    int(hybridization == x)
    for x in (Chem.rdchem.HybridizationType.SP,
    Chem.rdchem.HybridizationType.SP2,
    Chem.rdchem.HybridizationType.SP3)
    ]
    h_u.append(num_h)
    atom_feats_dict['n_feat'].append(torch.FloatTensor(h_u))

    atom_feats_dict['n_feat'] = torch.stack(atom_feats_dict['n_feat'], dim=0)
    atom_feats_dict['pos'] = torch.stack(atom_feats_dict['pos'], dim=0)
    atom_feats_dict['node_type'] = torch.LongTensor(atom_feats_dict['node_type'])

    return atom_feats_dict

    def alchemy_edges(mol, self_loop=True):
    """Featurization for all bonds in a molecule. The bond indices
    will be preserved.
    Args:
    mol : rdkit.Chem.rdchem.Mol
    RDKit molecule object
    Returns
    bond_feats_dict : dict
    Dictionary for bond features
    """
    bond_feats_dict = defaultdict(list)

    mol_conformers = mol.GetConformers()
    assert len(mol_conformers) == 1
    geom = mol_conformers[0].GetPositions()

    num_atoms = mol.GetNumAtoms()
    for u in range(num_atoms):
    for v in range(num_atoms):
    if u == v and not self_loop:
    continue

    e_uv = mol.GetBondBetweenAtoms(u, v)
    if e_uv is None:
    bond_type = None
    else:
    bond_type = e_uv.GetBondType()
    bond_feats_dict['e_feat'].append([
    float(bond_type == x)
    for x in (Chem.rdchem.BondType.SINGLE,
    Chem.rdchem.BondType.DOUBLE,
    Chem.rdchem.BondType.TRIPLE,
    Chem.rdchem.BondType.AROMATIC, None)
    ])
    bond_feats_dict['distance'].append(
    np.linalg.norm(geom[u] - geom[v]))

    bond_feats_dict['e_feat'] = torch.FloatTensor(
    bond_feats_dict['e_feat'])
    bond_feats_dict['distance'] = torch.FloatTensor(
    bond_feats_dict['distance']).reshape(-1, 1)

    return bond_feats_dict

    def mol_to_dgl(mol, self_loop=False):
    """
    Read sdf file and convert to dgl_graph
    Args:
    mol: Chem.rdchem.Mol
    self_loop: Whether to add self loop
    Returns:
    g: DGLGraph
    """
    g = dgl.DGLGraph()

    # add nodes
    num_atoms = mol.GetNumAtoms()
    atom_feats = alchemy_nodes(mol)
    g.add_nodes(num=num_atoms, data=atom_feats)

    # add edges
    # The model we were interested assumes a complete graph.
    # If this is not the case, do the code below instead
    #
    # for bond in mol.GetBonds():
    # u = bond.GetBeginAtomIdx()
    # v = bond.GetEndAtomIdx()
    if self_loop:
    g.add_edges(
    [i for i in range(num_atoms) for j in range(num_atoms)],
    [j for i in range(num_atoms) for j in range(num_atoms)])
    else:
    g.add_edges(
    [i for i in range(num_atoms) for j in range(num_atoms - 1)], [
    j for i in range(num_atoms)
    for j in range(num_atoms) if i != j
    ])

    bond_feats = alchemy_edges(mol, self_loop)
    g.edata.update(bond_feats)
    return g

    def batch_mol_to_dgl(batch_mols, num_processes):
    with Pool(processes=num_processes) as pool:
    results = pool.map(mol_to_dgl, batch_mols,
    chunksize=(len(batch_mols) // num_processes + 1))
    return results

    def batcher():
    def batcher_dev(batch):
    graphs, labels = zip(*batch)
    @@ -38,118 +193,6 @@ class TencentAlchemyDataset(Dataset):
    fdef_name = osp.join(RDConfig.RDDataDir, 'BaseFeatures.fdef')
    chem_feature_factory = ChemicalFeatures.BuildFeatureFactory(fdef_name)

    def alchemy_nodes(self, mol):
    """Featurization for all atoms in a molecule. The atom indices
    will be preserved.
    Args:
    mol : rdkit.Chem.rdchem.Mol
    RDKit molecule object
    Returns
    atom_feats_dict : dict
    Dictionary for atom features
    """
    atom_feats_dict = defaultdict(list)
    is_donor = defaultdict(int)
    is_acceptor = defaultdict(int)

    fdef_name = osp.join(RDConfig.RDDataDir, 'BaseFeatures.fdef')
    mol_featurizer = ChemicalFeatures.BuildFeatureFactory(fdef_name)
    mol_feats = mol_featurizer.GetFeaturesForMol(mol)
    mol_conformers = mol.GetConformers()
    assert len(mol_conformers) == 1
    geom = mol_conformers[0].GetPositions()

    for i in range(len(mol_feats)):
    if mol_feats[i].GetFamily() == 'Donor':
    node_list = mol_feats[i].GetAtomIds()
    for u in node_list:
    is_donor[u] = 1
    elif mol_feats[i].GetFamily() == 'Acceptor':
    node_list = mol_feats[i].GetAtomIds()
    for u in node_list:
    is_acceptor[u] = 1

    num_atoms = mol.GetNumAtoms()
    for u in range(num_atoms):
    atom = mol.GetAtomWithIdx(u)
    symbol = atom.GetSymbol()
    atom_type = atom.GetAtomicNum()
    aromatic = atom.GetIsAromatic()
    hybridization = atom.GetHybridization()
    num_h = atom.GetTotalNumHs()
    atom_feats_dict['pos'].append(torch.FloatTensor(geom[u]))
    atom_feats_dict['node_type'].append(atom_type)

    h_u = []
    h_u += [
    int(symbol == x) for x in ['H', 'C', 'N', 'O', 'F', 'S', 'Cl']
    ]
    h_u.append(atom_type)
    h_u.append(is_acceptor[u])
    h_u.append(is_donor[u])
    h_u.append(int(aromatic))
    h_u += [
    int(hybridization == x)
    for x in (Chem.rdchem.HybridizationType.SP,
    Chem.rdchem.HybridizationType.SP2,
    Chem.rdchem.HybridizationType.SP3)
    ]
    h_u.append(num_h)
    atom_feats_dict['n_feat'].append(torch.FloatTensor(h_u))

    atom_feats_dict['n_feat'] = torch.stack(atom_feats_dict['n_feat'],
    dim=0)
    atom_feats_dict['pos'] = torch.stack(atom_feats_dict['pos'], dim=0)
    atom_feats_dict['node_type'] = torch.LongTensor(atom_feats_dict['node_type'])

    return atom_feats_dict

    def alchemy_edges(self, mol, self_loop=True):
    """Featurization for all bonds in a molecule. The bond indices
    will be preserved.
    Args:
    mol : rdkit.Chem.rdchem.Mol
    RDKit molecule object
    Returns
    bond_feats_dict : dict
    Dictionary for bond features
    """
    bond_feats_dict = defaultdict(list)

    mol_conformers = mol.GetConformers()
    assert len(mol_conformers) == 1
    geom = mol_conformers[0].GetPositions()

    num_atoms = mol.GetNumAtoms()
    for u in range(num_atoms):
    for v in range(num_atoms):
    if u == v and not self_loop:
    continue

    e_uv = mol.GetBondBetweenAtoms(u, v)
    if e_uv is None:
    bond_type = None
    else:
    bond_type = e_uv.GetBondType()
    bond_feats_dict['e_feat'].append([
    float(bond_type == x)
    for x in (Chem.rdchem.BondType.SINGLE,
    Chem.rdchem.BondType.DOUBLE,
    Chem.rdchem.BondType.TRIPLE,
    Chem.rdchem.BondType.AROMATIC, None)
    ])
    bond_feats_dict['distance'].append(
    np.linalg.norm(geom[u] - geom[v]))

    bond_feats_dict['e_feat'] = torch.FloatTensor(
    bond_feats_dict['e_feat'])
    bond_feats_dict['distance'] = torch.FloatTensor(
    bond_feats_dict['distance']).reshape(-1, 1)

    return bond_feats_dict

    def _get_label(self, sdf_file):
    """Get molecule label.
    Args:
    @@ -174,44 +217,6 @@ def sdf_to_mol(self, sdf_file):
    mol = Chem.MolFromMolBlock(sdf, removeHs=False)
    return mol

    def mol_to_dgl(self, mol, self_loop=False):
    """
    Read sdf file and convert to dgl_graph
    Args:
    mol: Chem.rdchem.Mol
    self_loop: Whether to add self loop
    Returns:
    g: DGLGraph
    """
    g = dgl.DGLGraph()

    # add nodes
    num_atoms = mol.GetNumAtoms()
    atom_feats = self.alchemy_nodes(mol)
    g.add_nodes(num=num_atoms, data=atom_feats)

    # add edges
    # The model we were interested assumes a complete graph.
    # If this is not the case, do the code below instead
    #
    # for bond in mol.GetBonds():
    # u = bond.GetBeginAtomIdx()
    # v = bond.GetEndAtomIdx()
    if self_loop:
    g.add_edges(
    [i for i in range(num_atoms) for j in range(num_atoms)],
    [j for i in range(num_atoms) for j in range(num_atoms)])
    else:
    g.add_edges(
    [i for i in range(num_atoms) for j in range(num_atoms - 1)], [
    j for i in range(num_atoms)
    for j in range(num_atoms) if i != j
    ])

    bond_feats = self.alchemy_edges(mol, self_loop)
    g.edata.update(bond_feats)
    return g

    def __init__(self, mode='dev', num_processes=1, transform=None, **kwargs):
    assert mode in ['dev', 'valid',
    'test'], "mode should be dev/valid/test"
    @@ -256,10 +261,7 @@ def _load(self, num_processes, **kwargs):
    batch_id = 0
    while batch_id * batch_size < len(sdf_file_list):
    batch_mols = mols[batch_id * batch_size:(batch_id + 1) * batch_size]
    with Pool(processes=num_processes) as pool:
    self.graphs.extend(
    pool.map(self.mol_to_dgl, batch_mols,
    chunksize=(len(batch_mols) // num_processes + 1)))
    self.graphs.extend(batch_mol_to_dgl(batch_mols, num_processes))
    batch_id += 1

    self.normalize()
    @@ -291,7 +293,7 @@ def __getitem__(self, idx):
    parser = argparse.ArgumentParser()
    parser.add_argument('-np', '--num-processes', type=int, default=2,
    help='Number of subprocesses to use for graph construction and featurization')
    parser.add_argument('-b', '--batch-size', type=int, default=128,
    parser.add_argument('-b', '--batch-size', type=int, default=2,
    help='Batch size for processing')
    args = parser.parse_args()

  2. mufeili revised this gist Aug 22, 2019. 1 changed file with 1 addition and 1 deletion.
    2 changes: 1 addition & 1 deletion alchemy.py
    Original file line number Diff line number Diff line change
    @@ -226,7 +226,7 @@ def __init__(self, mode='dev', num_processes=1, transform=None, **kwargs):
    archive.extractall('./Alchemy_data')
    archive.close()

    self._load(num_processes, kwargs)
    self._load(num_processes, **kwargs)

    def _load(self, num_processes, **kwargs):
    if self.mode == 'dev':
  3. mufeili revised this gist Aug 22, 2019. 1 changed file with 19 additions and 9 deletions.
    28 changes: 19 additions & 9 deletions alchemy.py
    Original file line number Diff line number Diff line change
    @@ -212,7 +212,7 @@ def mol_to_dgl(self, mol, self_loop=False):
    g.edata.update(bond_feats)
    return g

    def __init__(self, mode='dev', num_processes=1, transform=None):
    def __init__(self, mode='dev', num_processes=1, transform=None, **kwargs):
    assert mode in ['dev', 'valid',
    'test'], "mode should be dev/valid/test"
    self.mode = mode
    @@ -226,9 +226,9 @@ def __init__(self, mode='dev', num_processes=1, transform=None):
    archive.extractall('./Alchemy_data')
    archive.close()

    self._load(num_processes)
    self._load(num_processes, kwargs)

    def _load(self, num_processes):
    def _load(self, num_processes, **kwargs):
    if self.mode == 'dev':
    target_file = pathlib.Path(self.file_dir, "dev_target.csv")
    self.target = pd.read_csv(target_file,
    @@ -247,13 +247,20 @@ def _load(self, num_processes):
    self.graphs.append(self.mol_to_dgl(mol))
    self.labels.append(self._get_label(sdf_file))
    else:
    batch_size = kwargs.get('batch_size')
    sdf_file_list = list(sdf_dir.glob("**/*.sdf"))
    mols = []
    for sdf_file in sdf_dir.glob("**/*.sdf"):
    for sdf_file in sdf_file_list:
    mols.append(self.sdf_to_mol(sdf_file))
    self.labels.append(self._get_label(sdf_file))
    with Pool(processes=num_processes) as pool:
    self.graphs = pool.map(self.mol_to_dgl, mols,
    chunksize=(len(mols) // num_processes + 1))
    batch_id = 0
    while batch_id * batch_size < len(sdf_file_list):
    batch_mols = mols[batch_id * batch_size:(batch_id + 1) * batch_size]
    with Pool(processes=num_processes) as pool:
    self.graphs.extend(
    pool.map(self.mol_to_dgl, batch_mols,
    chunksize=(len(batch_mols) // num_processes + 1)))
    batch_id += 1

    self.normalize()
    print(len(self.graphs), "loaded!")
    @@ -282,12 +289,15 @@ def __getitem__(self, idx):
    import time

    parser = argparse.ArgumentParser()
    parser.add_argument('-np', '--num-processes', type=int, default=128,
    parser.add_argument('-np', '--num-processes', type=int, default=2,
    help='Number of subprocesses to use for graph construction and featurization')
    parser.add_argument('-b', '--batch-size', type=int, default=128,
    help='Batch size for processing')
    args = parser.parse_args()

    t1 = time.time()
    alchemy_dataset = TencentAlchemyDataset(num_processes=args.num_processes)
    alchemy_dataset = TencentAlchemyDataset(num_processes=args.num_processes,
    batch_size=args.batch_size)
    t2 = time.time()
    print('It took {} with {:d} processes.'.format(datetime.timedelta(seconds=t2 - t1),
    args.num_processes))
  4. mufeili revised this gist Aug 21, 2019. 1 changed file with 2 additions and 1 deletion.
    3 changes: 2 additions & 1 deletion alchemy.py
    Original file line number Diff line number Diff line change
    @@ -252,7 +252,8 @@ def _load(self, num_processes):
    mols.append(self.sdf_to_mol(sdf_file))
    self.labels.append(self._get_label(sdf_file))
    with Pool(processes=num_processes) as pool:
    self.graphs = pool.map(self.mol_to_dgl, mols, chunksize=1)
    self.graphs = pool.map(self.mol_to_dgl, mols,
    chunksize=(len(mols) // num_processes + 1))

    self.normalize()
    print(len(self.graphs), "loaded!")
  5. mufeili revised this gist Aug 21, 2019. 1 changed file with 2 additions and 1 deletion.
    3 changes: 2 additions & 1 deletion alchemy.py
    Original file line number Diff line number Diff line change
    @@ -169,7 +169,8 @@ def sdf_to_mol(self, sdf_file):
    Returns:
    mol: Chem.rdchem.Mol
    """
    sdf = open(str(sdf_file)).read()
    with open(str(sdf_file)) as f:
    sdf = f.read()
    mol = Chem.MolFromMolBlock(sdf, removeHs=False)
    return mol

  6. mufeili revised this gist Aug 21, 2019. 1 changed file with 3 additions and 6 deletions.
    9 changes: 3 additions & 6 deletions alchemy.py
    Original file line number Diff line number Diff line change
    @@ -280,15 +280,12 @@ def __getitem__(self, idx):
    import time

    parser = argparse.ArgumentParser()
    parser.add_argument('-np', '--num-processes', type=int, default=2,
    parser.add_argument('-np', '--num-processes', type=int, default=128,
    help='Number of subprocesses to use for graph construction and featurization')
    args = parser.parse_args()

    t1 = time.time()
    alchemy_dataset = TencentAlchemyDataset()
    t2 = time.time()

    alchemy_dataset = TencentAlchemyDataset(num_processes=args.num_processes)
    print('It took {} with single process.'.format(datetime.timedelta(seconds=t2 - t1)))
    print('It took {} with {:d} processes.'.format(datetime.timedelta(seconds=t3 - t2),
    t2 = time.time()
    print('It took {} with {:d} processes.'.format(datetime.timedelta(seconds=t2 - t1),
    args.num_processes))
  7. mufeili revised this gist Aug 21, 2019. 1 changed file with 1 addition and 4 deletions.
    5 changes: 1 addition & 4 deletions alchemy.py
    Original file line number Diff line number Diff line change
    @@ -243,10 +243,7 @@ def _load(self, num_processes):
    if num_processes == 1:
    for sdf_file in sdf_dir.glob("**/*.sdf"):
    mol = self.sdf_to_mol(sdf_file)
    result = self.mol_to_dgl(mol)
    if result is None:
    continue
    self.graphs.append(result[0])
    self.graphs.append(self.mol_to_dgl(mol))
    self.labels.append(self._get_label(sdf_file))
    else:
    mols = []
  8. mufeili revised this gist Aug 21, 2019. 1 changed file with 19 additions and 12 deletions.
    31 changes: 19 additions & 12 deletions alchemy.py
    Original file line number Diff line number Diff line change
    @@ -150,6 +150,18 @@ def alchemy_edges(self, mol, self_loop=True):

    return bond_feats_dict

    def _get_label(self, sdf_file):
    """Get molecule label.
    Args:
    sdf_file: path of sdf file
    Returns
    l : molecule label
    """
    # for val/test set, labels are molecule ID
    l = torch.FloatTensor(self.target.loc[int(sdf_file.stem)].tolist()) \
    if self.mode == 'dev' else torch.LongTensor([int(sdf_file.stem)])
    return l

    def sdf_to_mol(self, sdf_file):
    """Read sdf file and convert it to a rdkit molecule object.
    Args:
    @@ -169,7 +181,6 @@ def mol_to_dgl(self, mol, self_loop=False):
    self_loop: Whether to add self loop
    Returns:
    g: DGLGraph
    l: related labels
    """
    g = dgl.DGLGraph()

    @@ -198,11 +209,7 @@ def mol_to_dgl(self, mol, self_loop=False):

    bond_feats = self.alchemy_edges(mol, self_loop)
    g.edata.update(bond_feats)

    # for val/test set, labels are molecule ID
    l = torch.FloatTensor(self.target.loc[int(sdf_file.stem)].tolist()) \
    if self.mode == 'dev' else torch.LongTensor([int(sdf_file.stem)])
    return (g, l)
    return g

    def __init__(self, mode='dev', num_processes=1, transform=None):
    assert mode in ['dev', 'valid',
    @@ -240,14 +247,14 @@ def _load(self, num_processes):
    if result is None:
    continue
    self.graphs.append(result[0])
    self.labels.append(result[1])
    self.labels.append(self._get_label(sdf_file))
    else:
    mols = [self.sdf_to_mol(sdf_file) for sdf_file in sdf_dir.glob("**/*.sdf")]
    mols = []
    for sdf_file in sdf_dir.glob("**/*.sdf"):
    mols.append(self.sdf_to_mol(sdf_file))
    self.labels.append(self._get_label(sdf_file))
    with Pool(processes=num_processes) as pool:
    all_result = pool.map(self.mol_to_dgl, mols)
    for result in all_result:
    self.graphs.append(result[0])
    self.labels.append(result[1])
    self.graphs = pool.map(self.mol_to_dgl, mols, chunksize=1)

    self.normalize()
    print(len(self.graphs), "loaded!")
  9. mufeili revised this gist Aug 21, 2019. 1 changed file with 18 additions and 8 deletions.
    26 changes: 18 additions & 8 deletions alchemy.py
    Original file line number Diff line number Diff line change
    @@ -150,19 +150,27 @@ def alchemy_edges(self, mol, self_loop=True):

    return bond_feats_dict

    def sdf_to_dgl(self, sdf_file, self_loop=False):
    """
    Read sdf file and convert to dgl_graph
    def sdf_to_mol(self, sdf_file):
    """Read sdf file and convert it to a rdkit molecule object.
    Args:
    sdf_file: path of sdf file
    self_loop: Whetaher to add self loop
    Returns:
    g: DGLGraph
    l: related labels
    mol: Chem.rdchem.Mol
    """
    sdf = open(str(sdf_file)).read()
    mol = Chem.MolFromMolBlock(sdf, removeHs=False)
    return mol

    def mol_to_dgl(self, mol, self_loop=False):
    """
    Read sdf file and convert to dgl_graph
    Args:
    mol: Chem.rdchem.Mol
    self_loop: Whether to add self loop
    Returns:
    g: DGLGraph
    l: related labels
    """
    g = dgl.DGLGraph()

    # add nodes
    @@ -227,14 +235,16 @@ def _load(self, num_processes):
    self.graphs, self.labels = [], []
    if num_processes == 1:
    for sdf_file in sdf_dir.glob("**/*.sdf"):
    result = self.sdf_to_dgl(sdf_file)
    mol = self.sdf_to_mol(sdf_file)
    result = self.mol_to_dgl(mol)
    if result is None:
    continue
    self.graphs.append(result[0])
    self.labels.append(result[1])
    else:
    mols = [self.sdf_to_mol(sdf_file) for sdf_file in sdf_dir.glob("**/*.sdf")]
    with Pool(processes=num_processes) as pool:
    all_result = pool.map(self.sdf_to_dgl, list(sdf_dir.glob("**/*.sdf")))
    all_result = pool.map(self.mol_to_dgl, mols)
    for result in all_result:
    self.graphs.append(result[0])
    self.labels.append(result[1])
  10. mufeili created this gist Aug 21, 2019.
    280 changes: 280 additions & 0 deletions alchemy.py
    Original file line number Diff line number Diff line change
    @@ -0,0 +1,280 @@
    # -*- coding:utf-8 -*-
    """Example dataloader of Tencent Alchemy Dataset
    https://alchemy.tencent.com/
    """
    import dgl
    import numpy as np
    import os
    import os.path as osp
    import pandas as pd
    import pathlib
    import torch
    import zipfile

    from collections import defaultdict
    from dgl.data.utils import download
    from multiprocessing import Pool
    from rdkit import Chem
    from rdkit.Chem import ChemicalFeatures
    from rdkit import RDConfig
    from torch.utils.data import Dataset
    from torch.utils.data import DataLoader
    _urls = {'Alchemy': 'https://alchemy.tencent.com/data/'}

    class AlchemyBatcher:
    def __init__(self, graph=None, label=None):
    self.graph = graph
    self.label = label

    def batcher():
    def batcher_dev(batch):
    graphs, labels = zip(*batch)
    batch_graphs = dgl.batch(graphs)
    labels = torch.stack(labels, 0)
    return AlchemyBatcher(graph=batch_graphs, label=labels)
    return batcher_dev

    class TencentAlchemyDataset(Dataset):
    fdef_name = osp.join(RDConfig.RDDataDir, 'BaseFeatures.fdef')
    chem_feature_factory = ChemicalFeatures.BuildFeatureFactory(fdef_name)

    def alchemy_nodes(self, mol):
    """Featurization for all atoms in a molecule. The atom indices
    will be preserved.
    Args:
    mol : rdkit.Chem.rdchem.Mol
    RDKit molecule object
    Returns
    atom_feats_dict : dict
    Dictionary for atom features
    """
    atom_feats_dict = defaultdict(list)
    is_donor = defaultdict(int)
    is_acceptor = defaultdict(int)

    fdef_name = osp.join(RDConfig.RDDataDir, 'BaseFeatures.fdef')
    mol_featurizer = ChemicalFeatures.BuildFeatureFactory(fdef_name)
    mol_feats = mol_featurizer.GetFeaturesForMol(mol)
    mol_conformers = mol.GetConformers()
    assert len(mol_conformers) == 1
    geom = mol_conformers[0].GetPositions()

    for i in range(len(mol_feats)):
    if mol_feats[i].GetFamily() == 'Donor':
    node_list = mol_feats[i].GetAtomIds()
    for u in node_list:
    is_donor[u] = 1
    elif mol_feats[i].GetFamily() == 'Acceptor':
    node_list = mol_feats[i].GetAtomIds()
    for u in node_list:
    is_acceptor[u] = 1

    num_atoms = mol.GetNumAtoms()
    for u in range(num_atoms):
    atom = mol.GetAtomWithIdx(u)
    symbol = atom.GetSymbol()
    atom_type = atom.GetAtomicNum()
    aromatic = atom.GetIsAromatic()
    hybridization = atom.GetHybridization()
    num_h = atom.GetTotalNumHs()
    atom_feats_dict['pos'].append(torch.FloatTensor(geom[u]))
    atom_feats_dict['node_type'].append(atom_type)

    h_u = []
    h_u += [
    int(symbol == x) for x in ['H', 'C', 'N', 'O', 'F', 'S', 'Cl']
    ]
    h_u.append(atom_type)
    h_u.append(is_acceptor[u])
    h_u.append(is_donor[u])
    h_u.append(int(aromatic))
    h_u += [
    int(hybridization == x)
    for x in (Chem.rdchem.HybridizationType.SP,
    Chem.rdchem.HybridizationType.SP2,
    Chem.rdchem.HybridizationType.SP3)
    ]
    h_u.append(num_h)
    atom_feats_dict['n_feat'].append(torch.FloatTensor(h_u))

    atom_feats_dict['n_feat'] = torch.stack(atom_feats_dict['n_feat'],
    dim=0)
    atom_feats_dict['pos'] = torch.stack(atom_feats_dict['pos'], dim=0)
    atom_feats_dict['node_type'] = torch.LongTensor(atom_feats_dict['node_type'])

    return atom_feats_dict

    def alchemy_edges(self, mol, self_loop=True):
    """Featurization for all bonds in a molecule. The bond indices
    will be preserved.
    Args:
    mol : rdkit.Chem.rdchem.Mol
    RDKit molecule object
    Returns
    bond_feats_dict : dict
    Dictionary for bond features
    """
    bond_feats_dict = defaultdict(list)

    mol_conformers = mol.GetConformers()
    assert len(mol_conformers) == 1
    geom = mol_conformers[0].GetPositions()

    num_atoms = mol.GetNumAtoms()
    for u in range(num_atoms):
    for v in range(num_atoms):
    if u == v and not self_loop:
    continue

    e_uv = mol.GetBondBetweenAtoms(u, v)
    if e_uv is None:
    bond_type = None
    else:
    bond_type = e_uv.GetBondType()
    bond_feats_dict['e_feat'].append([
    float(bond_type == x)
    for x in (Chem.rdchem.BondType.SINGLE,
    Chem.rdchem.BondType.DOUBLE,
    Chem.rdchem.BondType.TRIPLE,
    Chem.rdchem.BondType.AROMATIC, None)
    ])
    bond_feats_dict['distance'].append(
    np.linalg.norm(geom[u] - geom[v]))

    bond_feats_dict['e_feat'] = torch.FloatTensor(
    bond_feats_dict['e_feat'])
    bond_feats_dict['distance'] = torch.FloatTensor(
    bond_feats_dict['distance']).reshape(-1, 1)

    return bond_feats_dict

    def sdf_to_dgl(self, sdf_file, self_loop=False):
    """
    Read sdf file and convert to dgl_graph
    Args:
    sdf_file: path of sdf file
    self_loop: Whetaher to add self loop
    Returns:
    g: DGLGraph
    l: related labels
    """
    sdf = open(str(sdf_file)).read()
    mol = Chem.MolFromMolBlock(sdf, removeHs=False)

    g = dgl.DGLGraph()

    # add nodes
    num_atoms = mol.GetNumAtoms()
    atom_feats = self.alchemy_nodes(mol)
    g.add_nodes(num=num_atoms, data=atom_feats)

    # add edges
    # The model we were interested assumes a complete graph.
    # If this is not the case, do the code below instead
    #
    # for bond in mol.GetBonds():
    # u = bond.GetBeginAtomIdx()
    # v = bond.GetEndAtomIdx()
    if self_loop:
    g.add_edges(
    [i for i in range(num_atoms) for j in range(num_atoms)],
    [j for i in range(num_atoms) for j in range(num_atoms)])
    else:
    g.add_edges(
    [i for i in range(num_atoms) for j in range(num_atoms - 1)], [
    j for i in range(num_atoms)
    for j in range(num_atoms) if i != j
    ])

    bond_feats = self.alchemy_edges(mol, self_loop)
    g.edata.update(bond_feats)

    # for val/test set, labels are molecule ID
    l = torch.FloatTensor(self.target.loc[int(sdf_file.stem)].tolist()) \
    if self.mode == 'dev' else torch.LongTensor([int(sdf_file.stem)])
    return (g, l)

    def __init__(self, mode='dev', num_processes=1, transform=None):
    assert mode in ['dev', 'valid',
    'test'], "mode should be dev/valid/test"
    self.mode = mode
    self.transform = transform
    self.file_dir = pathlib.Path('./Alchemy_data', mode)
    self.zip_file_path = pathlib.Path('./Alchemy_data', '%s.zip' % mode)
    download(_urls['Alchemy'] + "%s.zip" % mode,
    path=str(self.zip_file_path))
    if not os.path.exists(str(self.file_dir)):
    archive = zipfile.ZipFile(self.zip_file_path)
    archive.extractall('./Alchemy_data')
    archive.close()

    self._load(num_processes)

    def _load(self, num_processes):
    if self.mode == 'dev':
    target_file = pathlib.Path(self.file_dir, "dev_target.csv")
    self.target = pd.read_csv(target_file,
    index_col=0,
    usecols=[
    'gdb_idx',
    ] +
    ['property_%d' % x for x in range(12)])
    self.target = self.target[['property_%d' % x for x in range(12)]]

    sdf_dir = pathlib.Path(self.file_dir, "sdf")
    self.graphs, self.labels = [], []
    if num_processes == 1:
    for sdf_file in sdf_dir.glob("**/*.sdf"):
    result = self.sdf_to_dgl(sdf_file)
    if result is None:
    continue
    self.graphs.append(result[0])
    self.labels.append(result[1])
    else:
    with Pool(processes=num_processes) as pool:
    all_result = pool.map(self.sdf_to_dgl, list(sdf_dir.glob("**/*.sdf")))
    for result in all_result:
    self.graphs.append(result[0])
    self.labels.append(result[1])

    self.normalize()
    print(len(self.graphs), "loaded!")

    def normalize(self, mean=None, std=None):
    labels = np.array([i.numpy() for i in self.labels])
    if mean is None:
    mean = np.mean(labels, axis=0)
    if std is None:
    std = np.std(labels, axis=0)
    self.mean = mean
    self.std = std

    def __len__(self):
    return len(self.graphs)

    def __getitem__(self, idx):
    g, l = self.graphs[idx], self.labels[idx]
    if self.transform:
    g = self.transform(g)
    return g, l

    if __name__ == '__main__':
    import argparse
    import datetime
    import time

    parser = argparse.ArgumentParser()
    parser.add_argument('-np', '--num-processes', type=int, default=2,
    help='Number of subprocesses to use for graph construction and featurization')
    args = parser.parse_args()

    t1 = time.time()
    alchemy_dataset = TencentAlchemyDataset()
    t2 = time.time()

    alchemy_dataset = TencentAlchemyDataset(num_processes=args.num_processes)
    print('It took {} with single process.'.format(datetime.timedelta(seconds=t2 - t1)))
    print('It took {} with {:d} processes.'.format(datetime.timedelta(seconds=t3 - t2),
    args.num_processes))