Skip to content

Instantly share code, notes, and snippets.

@mufeili
Last active August 22, 2019 03:42
Show Gist options
  • Select an option

  • Save mufeili/4b103b809a8bd1241daf6fae6c1631ac to your computer and use it in GitHub Desktop.

Select an option

Save mufeili/4b103b809a8bd1241daf6fae6c1631ac to your computer and use it in GitHub Desktop.
Benchmark for multiprocess construction
# -*- coding:utf-8 -*-
"""Example dataloader of Tencent Alchemy Dataset
https://alchemy.tencent.com/
"""
import functools
import os
import os.path as osp
import pathlib
import zipfile
from collections import defaultdict
from multiprocessing import Pool

import dgl
import numpy as np
import pandas as pd
import torch
from dgl.data.utils import download
from rdkit import Chem
from rdkit import RDConfig
from rdkit.Chem import ChemicalFeatures
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
_urls = {'Alchemy': 'https://alchemy.tencent.com/data/'}
class AlchemyBatcher:
    """Simple holder pairing a batched graph with its batch of labels.

    Attributes:
        graph: the batched graph (e.g. the result of ``dgl.batch``), or None.
        label: the labels for that batch, or None.
    """

    def __init__(self, graph=None, label=None):
        self.graph, self.label = graph, label
@functools.lru_cache(maxsize=1)
def _get_feature_factory():
    """Build RDKit's chemical feature factory once per process and reuse it.

    ``BuildFeatureFactory`` parses ``BaseFeatures.fdef`` from disk; doing that
    for every molecule is repeated, loop-invariant work.
    """
    fdef_name = osp.join(RDConfig.RDDataDir, 'BaseFeatures.fdef')
    return ChemicalFeatures.BuildFeatureFactory(fdef_name)


def alchemy_nodes(mol):
    """Featurize every atom of a molecule, preserving RDKit atom indices.

    Args:
        mol : rdkit.Chem.rdchem.Mol
            RDKit molecule object carrying exactly one conformer (atom
            positions are read from it).

    Returns:
        atom_feats_dict : dict
            ``'pos'``: FloatTensor of stacked per-atom conformer positions,
            one row per atom (presumably (num_atoms, 3) -- confirm).
            ``'node_type'``: LongTensor of atomic numbers, shape (num_atoms,).
            ``'n_feat'``: FloatTensor of shape (num_atoms, 15) whose columns
            are: one-hot symbol over H/C/N/O/F/S/Cl (7), atomic number,
            acceptor flag, donor flag, aromatic flag, one-hot hybridization
            over SP/SP2/SP3 (3), total number of hydrogens.
    """
    # Mark atoms that belong to any Donor / Acceptor pharmacophore feature.
    is_donor = defaultdict(int)
    is_acceptor = defaultdict(int)
    for feat in _get_feature_factory().GetFeaturesForMol(mol):
        family = feat.GetFamily()
        if family == 'Donor':
            for u in feat.GetAtomIds():
                is_donor[u] = 1
        elif family == 'Acceptor':
            for u in feat.GetAtomIds():
                is_acceptor[u] = 1

    mol_conformers = mol.GetConformers()
    assert len(mol_conformers) == 1
    geom = mol_conformers[0].GetPositions()

    # One-hot vocabularies, built once instead of once per atom.
    symbols = ('H', 'C', 'N', 'O', 'F', 'S', 'Cl')
    hybridizations = (Chem.rdchem.HybridizationType.SP,
                      Chem.rdchem.HybridizationType.SP2,
                      Chem.rdchem.HybridizationType.SP3)

    positions, node_types, node_feats = [], [], []
    for u in range(mol.GetNumAtoms()):
        atom = mol.GetAtomWithIdx(u)
        symbol = atom.GetSymbol()
        atom_type = atom.GetAtomicNum()
        hybridization = atom.GetHybridization()

        h_u = [int(symbol == s) for s in symbols]
        h_u += [atom_type, is_acceptor[u], is_donor[u],
                int(atom.GetIsAromatic())]
        h_u += [int(hybridization == h) for h in hybridizations]
        h_u.append(atom.GetTotalNumHs())

        positions.append(torch.FloatTensor(geom[u]))
        node_types.append(atom_type)
        node_feats.append(torch.FloatTensor(h_u))

    # Keep the original return type (defaultdict) and key order for callers.
    atom_feats_dict = defaultdict(list)
    atom_feats_dict['pos'] = torch.stack(positions, dim=0)
    atom_feats_dict['node_type'] = torch.LongTensor(node_types)
    atom_feats_dict['n_feat'] = torch.stack(node_feats, dim=0)
    return atom_feats_dict
def alchemy_edges(mol, self_loop=True):
    """Featurize every ordered atom pair (u, v) of a molecule.

    Pairs are enumerated with ``u`` in the outer loop and ``v`` in the inner
    loop (skipping ``u == v`` unless ``self_loop``), so row k of each output
    tensor describes the k-th pair in that order.

    Args:
        mol : rdkit.Chem.rdchem.Mol
            RDKit molecule object carrying exactly one conformer.
        self_loop : bool
            If True, include the pairs (u, u); otherwise skip them.

    Returns:
        bond_feats_dict : dict
            ``'e_feat'``: FloatTensor of shape (num_pairs, 5), one-hot bond
            type over SINGLE / DOUBLE / TRIPLE / AROMATIC / no bond.
            ``'distance'``: FloatTensor of shape (num_pairs, 1), Euclidean
            distance between the two atoms' conformer positions.
    """
    # The trailing ``None`` slot encodes "no bond between u and v".
    bond_types = (Chem.rdchem.BondType.SINGLE,
                  Chem.rdchem.BondType.DOUBLE,
                  Chem.rdchem.BondType.TRIPLE,
                  Chem.rdchem.BondType.AROMATIC, None)

    mol_conformers = mol.GetConformers()
    assert len(mol_conformers) == 1
    geom = mol_conformers[0].GetPositions()

    e_feats, distances = [], []
    num_atoms = mol.GetNumAtoms()
    for u in range(num_atoms):
        for v in range(num_atoms):
            if u == v and not self_loop:
                continue
            bond = mol.GetBondBetweenAtoms(u, v)
            bond_type = None if bond is None else bond.GetBondType()
            e_feats.append([float(bond_type == t) for t in bond_types])
            distances.append(np.linalg.norm(geom[u] - geom[v]))

    bond_feats_dict = defaultdict(list)
    # reshape keeps the 2-D layout even with zero pairs (a single atom without
    # self loops), where FloatTensor([]) alone would be 1-D of shape (0,).
    bond_feats_dict['e_feat'] = torch.FloatTensor(e_feats).reshape(
        -1, len(bond_types))
    bond_feats_dict['distance'] = torch.FloatTensor(distances).reshape(-1, 1)
    return bond_feats_dict
def mol_to_dgl(mol, self_loop=False):
    """Convert an RDKit molecule into a complete DGLGraph with features.

    Every ordered pair of distinct atoms gets an edge, plus (u, u) when
    ``self_loop`` is True, whether or not a chemical bond exists. The bond
    type is carried in the edge features instead.

    Args:
        mol: Chem.rdchem.Mol carrying exactly one conformer.
        self_loop: Whether to add an edge from each atom to itself.

    Returns:
        g: DGLGraph whose node data comes from ``alchemy_nodes(mol)`` and
        whose edge data comes from ``alchemy_edges(mol, self_loop)``.
    """
    g = dgl.DGLGraph()
    num_atoms = mol.GetNumAtoms()
    g.add_nodes(num=num_atoms, data=alchemy_nodes(mol))

    # The model this targets assumes a complete graph. Edges are added with
    # u in the outer loop and v in the inner loop, skipping u == v unless
    # self_loop. alchemy_edges enumerates pairs in the same order, so its
    # per-pair features line up with these edges. For a bonds-only graph,
    # build src/dst from mol.GetBonds() (GetBeginAtomIdx / GetEndAtomIdx)
    # instead, and adapt alchemy_edges to match.
    src, dst = [], []
    for u in range(num_atoms):
        for v in range(num_atoms):
            if u != v or self_loop:
                src.append(u)
                dst.append(v)
    g.add_edges(src, dst)

    g.edata.update(alchemy_edges(mol, self_loop))
    return g
def batch_mol_to_dgl(batch_mols, num_processes):
    """Convert molecules to DGLGraphs in parallel, preserving input order.

    A new worker pool of ``num_processes`` processes is created for this
    call and terminated when the ``with`` block exits.

    Args:
        batch_mols: sequence of Chem.rdchem.Mol to convert with
            ``mol_to_dgl`` (default ``self_loop``).
        num_processes: number of worker processes.

    Returns:
        list of DGLGraph, one per molecule, in the order of ``batch_mols``.
    """
    with Pool(processes=num_processes) as workers:
        # Roughly one chunk per worker to keep task-dispatch overhead low.
        per_worker = len(batch_mols) // num_processes + 1
        graphs = workers.map(mol_to_dgl, batch_mols, chunksize=per_worker)
    return graphs
def batcher():
    """Return a collate function that merges (graph, label) samples.

    The returned callable takes a sequence of ``(graph, label)`` pairs (the
    shape ``TencentAlchemyDataset.__getitem__`` yields). It presumably serves
    as a ``DataLoader`` ``collate_fn``; confirm at the call site.
    """
    def batcher_dev(batch):
        # Transpose [(g0, l0), (g1, l1), ...] into (g0, g1, ...), (l0, l1, ...).
        graph_list, label_list = zip(*batch)
        return AlchemyBatcher(graph=dgl.batch(graph_list),
                              label=torch.stack(label_list, 0))

    return batcher_dev
class TencentAlchemyDataset(Dataset):
    """A Tencent Alchemy split exposed as (DGLGraph, label) pairs.

    Construction fetches ``<mode>.zip`` from the Alchemy site into
    ``./Alchemy_data``. It extracts the archive when ``./Alchemy_data/<mode>``
    does not exist yet, then converts every ``*.sdf`` file under
    ``<mode>/sdf`` into a complete DGLGraph with ``mol_to_dgl``.

    Args:
        mode: 'dev', 'valid' or 'test'.
        num_processes: 1 converts molecules sequentially; any other value
            converts them in batches with a multiprocessing pool of that size.
        transform: optional callable applied to each graph in ``__getitem__``.
        **kwargs: ``batch_size`` is the number of molecules handed to each
            pool. It is only used when ``num_processes != 1`` and defaults to
            all molecules in a single batch.
    """

    # RDKit feature factory built once when the class is defined.
    # NOTE(review): not used by the methods below; kept so existing external
    # references to these class attributes keep working.
    fdef_name = osp.join(RDConfig.RDDataDir, 'BaseFeatures.fdef')
    chem_feature_factory = ChemicalFeatures.BuildFeatureFactory(fdef_name)

    def _get_label(self, sdf_file):
        """Return the label for the molecule stored in ``sdf_file``.

        The file stem is parsed as the molecule's integer ID (gdb_idx).

        Args:
            sdf_file: pathlib.Path of the sdf file.

        Returns:
            For 'dev': FloatTensor of the 12 target properties for that ID.
            For 'valid'/'test': LongTensor ``[ID]``. No target file is loaded
            for these splits, so the ID stands in as the label.
        """
        gdb_idx = int(sdf_file.stem)
        if self.mode == 'dev':
            return torch.FloatTensor(self.target.loc[gdb_idx].tolist())
        return torch.LongTensor([gdb_idx])

    def sdf_to_mol(self, sdf_file):
        """Read an sdf file into an RDKit molecule, keeping explicit Hs.

        Args:
            sdf_file: path of the sdf file.

        Returns:
            mol: Chem.rdchem.Mol

        Raises:
            ValueError: if RDKit cannot parse the file.
        """
        with open(str(sdf_file)) as f:
            sdf = f.read()
        mol = Chem.MolFromMolBlock(sdf, removeHs=False)
        if mol is None:
            # MolFromMolBlock returns None on parse failure. Failing here
            # names the bad file instead of crashing later in featurization
            # (possibly inside a pool worker).
            raise ValueError('Failed to parse molecule from %s' % sdf_file)
        return mol

    def __init__(self, mode='dev', num_processes=1, transform=None, **kwargs):
        assert mode in ['dev', 'valid',
                        'test'], "mode should be dev/valid/test"
        self.mode = mode
        self.transform = transform
        self.file_dir = pathlib.Path('./Alchemy_data', mode)
        self.zip_file_path = pathlib.Path('./Alchemy_data', '%s.zip' % mode)
        download(_urls['Alchemy'] + "%s.zip" % mode,
                 path=str(self.zip_file_path))
        if not os.path.exists(str(self.file_dir)):
            # The context manager closes the archive even if extraction fails.
            with zipfile.ZipFile(self.zip_file_path) as archive:
                archive.extractall('./Alchemy_data')
        self._load(num_processes, **kwargs)

    def _load(self, num_processes, **kwargs):
        """Populate ``self.graphs`` / ``self.labels`` from the split's sdf files.

        Raises:
            ValueError: if ``batch_size`` is given and is less than 1.
        """
        if self.mode == 'dev':
            target_file = pathlib.Path(self.file_dir, "dev_target.csv")
            property_cols = ['property_%d' % x for x in range(12)]
            self.target = pd.read_csv(target_file,
                                      index_col=0,
                                      usecols=['gdb_idx'] + property_cols)
            self.target = self.target[property_cols]

        sdf_dir = pathlib.Path(self.file_dir, "sdf")
        self.graphs, self.labels = [], []
        if num_processes == 1:
            for sdf_file in sdf_dir.glob("**/*.sdf"):
                mol = self.sdf_to_mol(sdf_file)
                # mol_to_dgl is a module-level function, not a method.
                self.graphs.append(mol_to_dgl(mol))
                self.labels.append(self._get_label(sdf_file))
        else:
            sdf_file_list = list(sdf_dir.glob("**/*.sdf"))
            mols = []
            for sdf_file in sdf_file_list:
                mols.append(self.sdf_to_mol(sdf_file))
                self.labels.append(self._get_label(sdf_file))

            batch_size = kwargs.get('batch_size')
            if batch_size is None:
                # No batch size given: convert everything in one batch.
                batch_size = max(len(mols), 1)
            elif batch_size < 1:
                # A non-positive size would never advance through the list.
                raise ValueError(
                    'batch_size must be a positive integer, got %r'
                    % (batch_size,))
            for start in range(0, len(mols), batch_size):
                batch_mols = mols[start:start + batch_size]
                self.graphs.extend(
                    batch_mol_to_dgl(batch_mols, num_processes))
        self.normalize()
        print(len(self.graphs), "loaded!")

    def normalize(self, mean=None, std=None):
        """Compute and store per-column label statistics.

        Sets ``self.mean`` / ``self.std``. Each is computed over all labels
        along axis 0 unless given explicitly. The labels are not modified.
        """
        labels = np.array([i.numpy() for i in self.labels])
        if mean is None:
            mean = np.mean(labels, axis=0)
        if std is None:
            std = np.std(labels, axis=0)
        self.mean = mean
        self.std = std

    def __len__(self):
        return len(self.graphs)

    def __getitem__(self, idx):
        """Return ``(graph, label)``, applying ``transform`` to the graph."""
        g, l = self.graphs[idx], self.labels[idx]
        if self.transform:
            g = self.transform(g)
        return g, l
if __name__ == '__main__':
    # Benchmark: time how long dataset construction takes for a given
    # number of worker processes and batch size.
    import argparse
    import datetime
    import time

    parser = argparse.ArgumentParser()
    parser.add_argument('-np', '--num-processes', type=int, default=2,
                        help='Number of subprocesses to use for graph construction and featurization')
    parser.add_argument('-b', '--batch-size', type=int, default=2,
                        help='Batch size for processing')
    args = parser.parse_args()

    # perf_counter is monotonic and high-resolution. Unlike time.time(), a
    # system clock adjustment cannot distort the measured duration.
    start = time.perf_counter()
    alchemy_dataset = TencentAlchemyDataset(num_processes=args.num_processes,
                                            batch_size=args.batch_size)
    elapsed = time.perf_counter() - start
    print('It took {} with {:d} processes.'.format(
        datetime.timedelta(seconds=elapsed), args.num_processes))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment