Last active
August 22, 2019 03:42
-
-
Save mufeili/4b103b809a8bd1241daf6fae6c1631ac to your computer and use it in GitHub Desktop.
Benchmark for multiprocess construction
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| # -*- coding:utf-8 -*- | |
| """Example dataloader of Tencent Alchemy Dataset | |
| https://alchemy.tencent.com/ | |
| """ | |
| import dgl | |
| import numpy as np | |
| import os | |
| import os.path as osp | |
| import pandas as pd | |
| import pathlib | |
| import torch | |
| import zipfile | |
| from collections import defaultdict | |
| from dgl.data.utils import download | |
| from multiprocessing import Pool | |
| from rdkit import Chem | |
| from rdkit.Chem import ChemicalFeatures | |
| from rdkit import RDConfig | |
| from torch.utils.data import Dataset | |
| from torch.utils.data import DataLoader | |
| _urls = {'Alchemy': 'https://alchemy.tencent.com/data/'} | |
class AlchemyBatcher:
    """Plain container bundling a batched graph with its labels.

    Args:
        graph: value stored on ``self.graph`` (defaults to None)
        label: value stored on ``self.label`` (defaults to None)
    """

    def __init__(self, graph=None, label=None):
        self.graph, self.label = graph, label
# Holds the feature factory once it has been built in this process, so that
# the feature definition file is not re-read for every molecule.
_ALCHEMY_FEATURE_FACTORY = []


def _get_alchemy_feature_factory():
    """Return the RDKit feature factory, building it only on first use."""
    if not _ALCHEMY_FEATURE_FACTORY:
        fdef_name = osp.join(RDConfig.RDDataDir, 'BaseFeatures.fdef')
        _ALCHEMY_FEATURE_FACTORY.append(
            ChemicalFeatures.BuildFeatureFactory(fdef_name))
    return _ALCHEMY_FEATURE_FACTORY[0]


def alchemy_nodes(mol):
    """Featurization for all atoms in a molecule. The atom indices
    will be preserved.

    Args:
        mol : rdkit.Chem.rdchem.Mol
            RDKit molecule object; it must carry exactly one conformer.
    Returns
        atom_feats_dict : dict
            Dictionary for atom features with three entries:
            'pos'       -- stacked float tensor of per-atom conformer positions
            'node_type' -- long tensor of atomic numbers
            'n_feat'    -- stacked float tensor with 15 values per atom
                           (7 symbol one-hots, atomic number, acceptor flag,
                           donor flag, aromatic flag, 3 hybridization
                           one-hots, total number of hydrogens)
    """
    mol_feats = _get_alchemy_feature_factory().GetFeaturesForMol(mol)

    mol_conformers = mol.GetConformers()
    assert len(mol_conformers) == 1
    geom = mol_conformers[0].GetPositions()

    # Collect the atom indices that take part in a donor / acceptor feature.
    donor_atoms, acceptor_atoms = set(), set()
    for feat in mol_feats:
        family = feat.GetFamily()
        if family == 'Donor':
            donor_atoms.update(feat.GetAtomIds())
        elif family == 'Acceptor':
            acceptor_atoms.update(feat.GetAtomIds())

    symbols = ('H', 'C', 'N', 'O', 'F', 'S', 'Cl')
    hybridizations = (Chem.rdchem.HybridizationType.SP,
                      Chem.rdchem.HybridizationType.SP2,
                      Chem.rdchem.HybridizationType.SP3)

    positions, node_types, node_feats = [], [], []
    for u in range(mol.GetNumAtoms()):
        atom = mol.GetAtomWithIdx(u)
        symbol = atom.GetSymbol()
        atom_type = atom.GetAtomicNum()
        hybridization = atom.GetHybridization()

        h_u = [int(symbol == x) for x in symbols]
        h_u.append(atom_type)
        h_u.append(int(u in acceptor_atoms))
        h_u.append(int(u in donor_atoms))
        h_u.append(int(atom.GetIsAromatic()))
        h_u += [int(hybridization == x) for x in hybridizations]
        h_u.append(atom.GetTotalNumHs())

        positions.append(torch.FloatTensor(geom[u]))
        node_types.append(atom_type)
        node_feats.append(torch.FloatTensor(h_u))

    atom_feats_dict = defaultdict(list)
    atom_feats_dict['pos'] = torch.stack(positions, dim=0)
    atom_feats_dict['node_type'] = torch.LongTensor(node_types)
    atom_feats_dict['n_feat'] = torch.stack(node_feats, dim=0)
    return atom_feats_dict
def alchemy_edges(mol, self_loop=True):
    """Featurization for every ordered atom pair in a molecule.

    Pairs are visited in row-major order (u outer, v inner); pairs with
    u == v are only included when ``self_loop`` is true.

    Args:
        mol : rdkit.Chem.rdchem.Mol
            RDKit molecule object; it must carry exactly one conformer.
        self_loop : bool
            Whether to emit features for the (u, u) pairs.
    Returns
        bond_feats_dict : dict
            Dictionary for bond features:
            'e_feat'   -- float tensor, one-hot over
                          (single, double, triple, aromatic, no bond)
            'distance' -- float tensor reshaped to one column, the norm of
                          the difference between the two atom positions
    """
    conformers = mol.GetConformers()
    assert len(conformers) == 1
    coords = conformers[0].GetPositions()

    # The trailing None matches atom pairs that share no bond.
    bond_kinds = (Chem.rdchem.BondType.SINGLE,
                  Chem.rdchem.BondType.DOUBLE,
                  Chem.rdchem.BondType.TRIPLE,
                  Chem.rdchem.BondType.AROMATIC, None)

    n_atoms = mol.GetNumAtoms()
    bond_feats_dict = defaultdict(list)
    for src in range(n_atoms):
        for dst in range(n_atoms):
            if src != dst or self_loop:
                bond = mol.GetBondBetweenAtoms(src, dst)
                kind = bond.GetBondType() if bond is not None else None
                bond_feats_dict['e_feat'].append(
                    [float(kind == ref) for ref in bond_kinds])
                bond_feats_dict['distance'].append(
                    np.linalg.norm(coords[src] - coords[dst]))

    e_feat = torch.FloatTensor(bond_feats_dict['e_feat'])
    bond_feats_dict['e_feat'] = e_feat
    distance = torch.FloatTensor(bond_feats_dict['distance'])
    bond_feats_dict['distance'] = distance.reshape(-1, 1)
    return bond_feats_dict
def mol_to_dgl(mol, self_loop=False):
    """
    Convert an RDKit molecule into a DGLGraph over all ordered atom pairs.

    Args:
        mol: Chem.rdchem.Mol
        self_loop: Whether to add self loop
    Returns:
        g: DGLGraph carrying the node features from ``alchemy_nodes`` and
           the edge features from ``alchemy_edges``
    """
    n_atoms = mol.GetNumAtoms()

    g = dgl.DGLGraph()
    # Nodes: one per atom, featurized in atom-index order.
    node_data = alchemy_nodes(mol)
    g.add_nodes(num=n_atoms, data=node_data)

    # Edges: the model of interest assumes a complete graph, so every ordered
    # pair of atoms is connected. For a bond-only graph one would instead
    # iterate mol.GetBonds() and use each bond's begin/end atom indices.
    pairs = [(i, j)
             for i in range(n_atoms)
             for j in range(n_atoms)
             if self_loop or i != j]
    sources = [i for i, _ in pairs]
    targets = [j for _, j in pairs]
    g.add_edges(sources, targets)

    # Edge features are produced in the same row-major pair order as above.
    edge_data = alchemy_edges(mol, self_loop)
    g.edata.update(edge_data)
    return g
def batch_mol_to_dgl(batch_mols, num_processes):
    """Convert a batch of molecules to graphs with a pool of worker processes.

    Args:
        batch_mols: list of molecules, each passed to ``mol_to_dgl``
        num_processes: number of worker processes in the pool
    Returns:
        list with one ``mol_to_dgl`` result per molecule, in input order
    """
    with Pool(processes=num_processes) as workers:
        # Size the chunks so the whole batch is split into at most
        # ``num_processes`` pieces.
        per_chunk = len(batch_mols) // num_processes + 1
        graphs = workers.map(mol_to_dgl, batch_mols, chunksize=per_chunk)
    return graphs
def batcher():
    """Return a collate function that merges (graph, label) samples.

    The returned callable takes a sequence of (graph, label) pairs and
    returns an ``AlchemyBatcher`` whose ``graph`` is ``dgl.batch`` of the
    graphs and whose ``label`` is the labels stacked along dimension 0.
    """
    def batcher_dev(batch):
        graph_seq, label_seq = zip(*batch)
        return AlchemyBatcher(
            graph=dgl.batch(graph_seq),
            label=torch.stack(label_seq, dim=0),
        )

    return batcher_dev
class TencentAlchemyDataset(Dataset):
    """Tencent Alchemy molecules converted to DGL graphs with label tensors.

    On construction the archive for ``mode`` is downloaded into
    ``./Alchemy_data``, extracted if the target directory does not exist yet,
    and every ``*.sdf`` file below ``<mode>/sdf`` is turned into a graph.

    Args:
        mode: one of 'dev', 'valid', 'test'. In 'dev' mode the labels are the
            12 property columns of ``dev_target.csv``; otherwise the label is
            the molecule ID taken from the file name.
        num_processes: 1 builds the graphs in this process; any other value
            builds them with a process pool via ``batch_mol_to_dgl``.
        transform: optional callable applied to a graph in ``__getitem__``.
        **kwargs: ``batch_size`` (int, optional) -- number of molecules
            handed to the process pool at a time. When omitted, all
            molecules are processed as a single batch.
    """

    fdef_name = osp.join(RDConfig.RDDataDir, 'BaseFeatures.fdef')
    chem_feature_factory = ChemicalFeatures.BuildFeatureFactory(fdef_name)

    def _get_label(self, sdf_file):
        """Get molecule label.

        Args:
            sdf_file: path of sdf file
        Returns
            l : molecule label
        """
        mol_id = int(sdf_file.stem)
        if self.mode == 'dev':
            return torch.FloatTensor(self.target.loc[mol_id].tolist())
        # for val/test set, labels are molecule ID
        return torch.LongTensor([mol_id])

    def sdf_to_mol(self, sdf_file):
        """Read sdf file and convert it to a rdkit molecule object.

        Args:
            sdf_file: path of sdf file
        Returns:
            mol: Chem.rdchem.Mol
        """
        with open(str(sdf_file)) as f:
            sdf = f.read()
        # Hydrogens are kept.
        return Chem.MolFromMolBlock(sdf, removeHs=False)

    def __init__(self, mode='dev', num_processes=1, transform=None, **kwargs):
        assert mode in ['dev', 'valid',
                        'test'], "mode should be dev/valid/test"
        self.mode = mode
        self.transform = transform
        self.file_dir = pathlib.Path('./Alchemy_data', mode)
        self.zip_file_path = pathlib.Path('./Alchemy_data', '%s.zip' % mode)
        download(_urls['Alchemy'] + "%s.zip" % mode,
                 path=str(self.zip_file_path))
        if not os.path.exists(str(self.file_dir)):
            # The context manager closes the archive even if extraction fails.
            with zipfile.ZipFile(str(self.zip_file_path)) as archive:
                archive.extractall('./Alchemy_data')
        self._load(num_processes, **kwargs)

    def _load(self, num_processes, **kwargs):
        """Read targets (dev mode only) and build all graphs and labels.

        Args:
            num_processes: 1 for in-process construction, otherwise the
                number of pool workers.
            **kwargs: optional ``batch_size`` for the multi-process path.
        Raises:
            ValueError: if ``batch_size`` is given and smaller than 1.
        """
        if self.mode == 'dev':
            property_cols = ['property_%d' % x for x in range(12)]
            target_file = pathlib.Path(self.file_dir, "dev_target.csv")
            self.target = pd.read_csv(target_file,
                                      index_col=0,
                                      usecols=['gdb_idx'] + property_cols)
            self.target = self.target[property_cols]

        sdf_dir = pathlib.Path(self.file_dir, "sdf")
        self.graphs, self.labels = [], []

        if num_processes == 1:
            for sdf_file in sdf_dir.glob("**/*.sdf"):
                mol = self.sdf_to_mol(sdf_file)
                # mol_to_dgl is a module-level function, not a method.
                self.graphs.append(mol_to_dgl(mol))
                self.labels.append(self._get_label(sdf_file))
        else:
            batch_size = kwargs.get('batch_size')
            if batch_size is not None and batch_size < 1:
                raise ValueError(
                    "batch_size should be a positive integer, got %r"
                    % (batch_size,))

            mols = []
            for sdf_file in sdf_dir.glob("**/*.sdf"):
                mols.append(self.sdf_to_mol(sdf_file))
                self.labels.append(self._get_label(sdf_file))

            if batch_size is None:
                # No batch size requested: hand everything to the pool at
                # once (at least 1 so the range step below is valid).
                batch_size = max(len(mols), 1)

            for start in range(0, len(mols), batch_size):
                batch_mols = mols[start:start + batch_size]
                self.graphs.extend(batch_mol_to_dgl(batch_mols, num_processes))

        self.normalize()
        print(len(self.graphs), "loaded!")

    def normalize(self, mean=None, std=None):
        """Store label statistics on ``self.mean`` and ``self.std``.

        Values that are not supplied are computed over all loaded labels
        along axis 0. The labels themselves are not modified.
        """
        labels = np.array([i.numpy() for i in self.labels])
        if mean is None:
            mean = np.mean(labels, axis=0)
        if std is None:
            std = np.std(labels, axis=0)
        self.mean = mean
        self.std = std

    def __len__(self):
        """Number of loaded graphs."""
        return len(self.graphs)

    def __getitem__(self, idx):
        """Return the (graph, label) pair at ``idx``, transforming the graph
        when a transform was supplied."""
        g, l = self.graphs[idx], self.labels[idx]
        if self.transform:
            g = self.transform(g)
        return g, l
if __name__ == '__main__':
    import argparse
    import datetime
    import time

    # Benchmark entry point: time how long the dataset takes to construct
    # with the requested number of processes and batch size.
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument(
        '-np', '--num-processes', type=int, default=2,
        help='Number of subprocesses to use for graph construction and featurization')
    arg_parser.add_argument(
        '-b', '--batch-size', type=int, default=2,
        help='Batch size for processing')
    args = arg_parser.parse_args()

    started = time.time()
    alchemy_dataset = TencentAlchemyDataset(
        num_processes=args.num_processes,
        batch_size=args.batch_size)
    finished = time.time()

    elapsed = datetime.timedelta(seconds=finished - started)
    print('It took {} with {:d} processes.'.format(elapsed, args.num_processes))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment