-
Notifications
You must be signed in to change notification settings - Fork 1
/
preprocess.py
115 lines (92 loc) · 3.63 KB
/
preprocess.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
#! /usr/bin/env python
# -*- coding: utf-8 -*-
"""
Molecule preprocessing functions
"""
import os

import click
import numpy as np
import pandas as pd
from rdkit.Chem import CanonSmiles

from dataset import tokenizer
def keep_longest(smls, return_salt=False):
    """Keep only the longest fragment of each SMILES string after splitting at '.'.

    :param smls: {list} list of SMILES strings (a single string is also accepted)
    :param return_salt: {bool} whether to return the stripped salt fragments as well
    :return: {list} list of longest fragments (and optionally the salt lists)
    """
    if isinstance(smls, str):
        smls = [smls]
    parents, salts = [], []
    for smiles in smls:
        fragments = smiles.split(".")
        if len(fragments) > 1:
            # index of the first longest fragment (ties resolve to the first)
            longest = max(range(len(fragments)), key=lambda i: len(fragments[i]))
            parents.append(fragments[longest])
            salts.append(fragments[:longest] + fragments[longest + 1:])
        else:
            parents.append(smiles)
            salts.append([""])
    if return_salt:
        return parents, salts
    return parents
def harmonize_sc(mols):
    """Harmonize the sidechains of given SMILES strings to a normalized format.

    Currently rewrites uncharged nitro-group notations into the standard
    charge-separated form.

    :param mols: {list} molecules as SMILES string
    :return: {list} harmonized molecules as SMILES string
    """
    # (before, after) replacement pairs — hoisted out of the loop so the
    # list is built once instead of once per molecule
    # TODO: add more problematic sidechain representations that occur
    pairs = [
        ("[N](=O)[O-]", "[N+](=O)[O-]"),
        ("[O-][N](=O)", "[O-][N+](=O)"),
    ]
    out = []
    for mol in mols:
        for before, after in pairs:
            mol = mol.replace(before, after)
        out.append(mol)
    return out
def batchify(iterable, batch_size):
    """Yield successive chunks of at most ``batch_size`` items from ``iterable``.

    :param iterable: a sliceable sequence (e.g. list)
    :param batch_size: {int} maximum number of items per yielded chunk
    :return: generator yielding slices of ``iterable``
    """
    for start in range(0, len(iterable), batch_size):
        # slicing clamps at the end of the sequence automatically,
        # so the explicit min() with len(iterable) was redundant
        yield iterable[start : start + batch_size]
def preprocess_smiles_file(filename, smls_col, delimiter, max_len, batch_size):
    """Read a (possibly gzipped) delimited file of SMILES strings, clean them
    and return the valid, unique ones.

    Processing steps: keep the longest fragment, harmonize side chains,
    canonicalize with RDKit, drop strings longer than ``max_len`` or containing
    characters outside the tokenizer vocabulary, then deduplicate.

    :param filename: {str} path to the input file (``.gz`` files supported)
    :param smls_col: {str} name of the column that contains the SMILES
    :param delimiter: {str} column delimiter of the input file
    :param max_len: {int} maximum allowed SMILES length in characters
    :param batch_size: {int} number of SMILES processed per chunk
    :return: {pd.DataFrame} single-column frame ("SMILES") of unique strings
    """

    def canon(s):
        # canonicalize one SMILES; None on RDKit failure or over-length result
        try:
            o = CanonSmiles(s)
            return o if len(o) <= max_len else None
        except Exception:  # RDKit raises assorted errors on invalid input
            return None

    _, t2i = tokenizer()
    vocab = set(t2i)  # set membership per character instead of list-building `.keys()` check
    # single read path: select compression explicitly instead of duplicating the call
    compression = "gzip" if filename.endswith(".gz") else "infer"
    data = pd.read_csv(filename, delimiter=delimiter, compression=compression).rename(
        columns={smls_col: "SMILES"}
    )
    print(f"{len(data)} SMILES strings read")
    print("Keeping longest fragment...")
    smls = keep_longest(data.SMILES.values)
    del data  # cleanup to save memory
    print("Harmonizing side chains...")
    smls = harmonize_sc(smls)
    print("Checking SMILES validity...")
    out = []
    for batch in batchify(smls, batch_size):
        valid = [o for o in (canon(s) for s in batch) if o is not None]
        out.extend(s for s in valid if all(c in vocab for c in s))
    uniq = list(set(out))
    print(f"{len(uniq)} valid unique SMILES strings obtained")
    return pd.DataFrame({"SMILES": uniq})
@click.command()
@click.argument("filename")
@click.option("-d", "--delimiter", default="\t", help="Column delimiter of input file.")
@click.option("-c", "--smls_col", default="SMILES", help="Name of column that contains SMILES.")
@click.option("-l", "--max_len", default=150, help="Maximum length of SMILES string in characters.")
@click.option("-b", "--batch_size", default=100000, help="Batch size used to chunk up SMILES list for processing.")
def main(filename, smls_col, delimiter, max_len, batch_size):
    """Preprocess a SMILES file and save the result next to it as
    ``<basename>_proc.txt.gz``."""
    data = preprocess_smiles_file(filename, smls_col, delimiter, max_len, batch_size)
    # Strip the real extension instead of blindly cutting 4 characters:
    # "data.csv.gz"[:-4] would have produced the mangled stem "data.cs".
    base, ext = os.path.splitext(filename)
    if ext == ".gz":
        base = os.path.splitext(base)[0]  # drop the inner extension too
    out_file = f"{base}_proc.txt.gz"
    data.to_csv(out_file, sep="\t", index=False, compression="gzip")
    print(f"preprocessing completed! Saved {len(data)} SMILES to {out_file}")
# Run the CLI entry point only when executed as a script.
if __name__ == "__main__":
    main()