antoniaebner committed on
Commit 81226cb · 1 Parent(s): 7f6d1d6

add RF framework

Files changed (6)
  1. data.py +158 -0
  2. model.py +60 -0
  3. predict.py +42 -0
  4. requirements.txt +7 -0
  5. train.py +79 -0
  6. utils.py +441 -0
data.py ADDED
@@ -0,0 +1,158 @@
+ # pipeline taken from https://huggingface.co/spaces/ml-jku/mhnfs/blob/main/src/data_preprocessing/create_descriptors.py
+
+ """
+ This file includes the data preprocessing for Tox21.
+ As input it takes a list of SMILES and it outputs a normalized feature matrix
+ together with the indices of molecules that could not be standardized.
+ """
+
+ import os
+ from typing import List
+
+ import numpy as np
+
+ from sklearn.preprocessing import StandardScaler
+ from statsmodels.distributions.empirical_distribution import ECDF
+
+ from rdkit import Chem, DataStructs
+ from rdkit.Chem import Descriptors, rdFingerprintGenerator
+ from rdkit.Chem.rdchem import Mol
+
+ from utils import USED_200_DESCR, Standardizer, load_pickle, write_pickle
+
+
+ def preprocess_molecules(
+     smiles_list: list[str],
+     load_ecdf_path: str = "",
+     load_scaler_path: str = "",
+     save_ecdf_path: str = "",
+     save_scaler_path: str = "",
+ ) -> tuple:
+     """Preprocess a list of molecules into normalized feature vectors."""
+     assert not (
+         load_ecdf_path and save_ecdf_path
+     ), "Cannot pass 'load_ecdf_path' and 'save_ecdf_path' simultaneously"
+     assert not (
+         load_scaler_path and save_scaler_path
+     ), "Cannot pass 'load_scaler_path' and 'save_scaler_path' simultaneously"
+
+     ecdfs = (
+         load_pickle(load_ecdf_path)
+         if load_ecdf_path and os.path.exists(load_ecdf_path)
+         else None
+     )
+     scaler = (
+         load_pickle(load_scaler_path)
+         if load_scaler_path and os.path.exists(load_scaler_path)
+         else None
+     )
+
+     # Create cleaned RDKit mol objects
+     mols, removed_idxs = create_cleaned_mol_objects(smiles_list)
+     print("Cleaned molecules")
+
+     # Create fingerprints and descriptors
+     ecfps = create_ecfp_fps(mols)
+     print("Created ECFP fingerprints")
+     rdkit_descrs = create_rdkit_descriptors(mols)
+     print("Created RDKit descriptors")
+
+     # Create and save ECDFs
+     if ecdfs is None:
+         print("Create ECDFs")
+         ecdfs = []
+         for column in range(rdkit_descrs.shape[1]):
+             raw_values = rdkit_descrs[:, column].reshape(-1)
+             ecdfs.append(ECDF(raw_values))
+         if save_ecdf_path:
+             write_pickle(save_ecdf_path, ecdfs)
+             print(f"Saved ECDFs under {save_ecdf_path}")
+
+     # Create quantiles
+     rdkit_descr_quantils = create_quantils(rdkit_descrs, ecdfs)
+     print("Created quantiles of RDKit descriptors")
+
+     # Concatenate features
+     raw_features = np.concatenate((ecfps, rdkit_descr_quantils), axis=1)
+
+     if scaler is None:
+         scaler = StandardScaler()
+         scaler.fit(raw_features)
+         print("Fitted the StandardScaler")
+         if save_scaler_path:
+             write_pickle(save_scaler_path, scaler)
+             print(f"Saved the StandardScaler under {save_scaler_path}")
+
+     # Normalize feature vectors
+     normalized_features = scaler.transform(raw_features)
+     print("Normalized the molecule features")
+
+     return normalized_features, removed_idxs
+
+
+ def create_cleaned_mol_objects(smiles: List[str]) -> tuple:
+     """
+     This function creates cleaned RDKit mol objects from a list of SMILES.
+     It returns the cleaned molecules and the indices of SMILES that could
+     not be standardized.
+     """
+     sm = Standardizer(canon_taut=True)
+
+     removed_idxs = list()
+     mols = list()
+     for i, smile in enumerate(smiles):
+         mol = Chem.MolFromSmiles(smile)
+         standardized_mol, _ = sm.standardize_mol(mol)
+         if standardized_mol is None:
+             removed_idxs.append(i)
+             continue
+         can_mol = Chem.MolFromSmiles(Chem.MolToSmiles(standardized_mol))
+         mols.append(can_mol)
+     return mols, removed_idxs
+
+
+ def create_ecfp_fps(mols: List[Mol]) -> np.ndarray:
+     """
+     This function creates ECFP fingerprints for a list of molecules.
+     """
+     ecfps = list()
+
+     for mol in mols:
+         fp_sparse_vec = rdFingerprintGenerator.GetCountFPs(
+             [mol], fpType=rdFingerprintGenerator.MorganFP
+         )[0]
+         fp = np.zeros((0,), np.int8)
+         DataStructs.ConvertToNumpyArray(fp_sparse_vec, fp)
+
+         ecfps.append(fp)
+
+     return np.array(ecfps)
+
+
+ def create_rdkit_descriptors(mols: List[Mol]) -> np.ndarray:
+     """
+     This function creates RDKit descriptors for a list of molecules.
+     """
+     rdkit_descriptors = list()
+
+     for mol in mols:
+         descrs = []
+         for _, descr_calc_fn in Descriptors._descList:
+             descrs.append(descr_calc_fn(mol))
+
+         descrs = np.array(descrs)
+         descrs = descrs[USED_200_DESCR]
+         rdkit_descriptors.append(descrs)
+
+     return np.array(rdkit_descriptors)
+
+
+ def create_quantils(raw_features: np.ndarray, ecdfs: list) -> np.ndarray:
+     """Map each descriptor column to its quantile via the fitted ECDF."""
+     quantils = np.zeros_like(raw_features)
+
+     for column in range(raw_features.shape[1]):
+         raw_values = raw_features[:, column].reshape(-1)
+         ecdf = ecdfs[column]
+         q = ecdf(raw_values)
+         quantils[:, column] = q
+
+     return quantils
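
A minimal usage sketch of the preprocessing entry point. The call below fits and stores the ECDFs and scaler on first use (assuming an assets/ directory exists); SMILES that cannot be parsed or standardized are reported via removed_idxs instead of raising:

from data import preprocess_molecules

smiles = ["CCO", "c1ccccc1O", "not-a-smiles"]
features, removed_idxs = preprocess_molecules(
    smiles,
    save_ecdf_path="assets/ecdfs.pkl",
    save_scaler_path="assets/scaler.pkl",
)
print(features.shape)   # (number of valid molecules, n_features)
print(removed_idxs)     # [2] -- the SMILES that could not be processed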
model.py ADDED
@@ -0,0 +1,60 @@
+ """
+ This file includes the RF model for Tox21.
+ It wraps one RandomForestClassifier per Tox21 task and assigns a toxicity
+ score to a given vector of molecule features.
+ """
+
+ # ---------------------------------------------------------------------------------------
+ # Dependencies
+ import os
+ import joblib
+
+ import numpy as np
+ from sklearn.ensemble import RandomForestClassifier
+
+ from utils import TASKS
+
+
+ # ---------------------------------------------------------------------------------------
+ class Tox21RFClassifier:
+     """
+     A random forest classifier that assigns a per-task toxicity score to a
+     given molecule feature vector.
+     """
+
+     def __init__(self, seed: int = 42):
+         self.tasks = TASKS
+         self.model = {
+             task: RandomForestClassifier(n_estimators=1001, random_state=seed)
+             for task in self.tasks
+         }
+
+     def load_model(self, folder: str):
+         """
+         Loads the per-task models from a given checkpoint folder.
+         """
+         self.model = {
+             task: joblib.load(os.path.join(folder, f"rf_{task}.joblib"))
+             for task in self.tasks
+         }
+
+     def save_model(self, folder: str):
+         """
+         Saves the per-task models to a given folder.
+         """
+         if not os.path.exists(folder):
+             os.makedirs(folder)
+
+         for task, model in self.model.items():
+             joblib.dump(model, os.path.join(folder, f"rf_{task}.joblib"))
+
+     def fit(self, task: str, input_features: np.ndarray, labels: np.ndarray) -> None:
+         """Fits the classifier of the given task on the provided features and labels."""
+         assert task in self.tasks, f"Unknown task: {task}"
+         self.model[task].fit(input_features, labels)
+
+     def predict(self, task: str, features: np.ndarray) -> np.ndarray:
+         """
+         Predicts the given Tox21 target for a np.ndarray of molecule features.
+         """
+         assert task in self.tasks, f"Unknown task: {task}"
+         preds = self.model[task].predict_proba(features)
+         return preds[:, 1]
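
A short sketch of how the per-task classifiers might be driven directly; the feature matrix and labels below are toy placeholders, not the real Tox21 descriptors:

import numpy as np
from model import Tox21RFClassifier

model = Tox21RFClassifier(seed=42)
X = np.random.rand(20, 10)            # toy feature matrix
y = np.random.randint(0, 2, size=20)  # toy binary labels
model.fit("NR-AR", X, y)
scores = model.predict("NR-AR", X)    # probability of the toxic class, shape (20,)
model.save_model("assets/model/")     # writes one rf_<task>.joblib per task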
predict.py ADDED
@@ -0,0 +1,42 @@
+ """
+ This file includes the predict function for Tox21.
+ As input it takes a list of SMILES and it outputs a nested dictionary with
+ SMILES and target names as keys.
+ """
+
+ # ---------------------------------------------------------------------------------------
+ # Dependencies
+ from typing import List
+ from collections import defaultdict
+
+ from data import preprocess_molecules
+ from model import Tox21RFClassifier
+
+ # ---------------------------------------------------------------------------------------
+
+
+ def predict(smiles_list: List[str]) -> dict:
+     """
+     Applies the classifier to a list of SMILES strings.
+     """
+     # preprocessing pipeline
+     features, removed_idxs = preprocess_molecules(
+         smiles_list,
+         load_ecdf_path="assets/ecdfs.pkl",
+         load_scaler_path="assets/scaler.pkl",
+     )
+
+     # setup model
+     model = Tox21RFClassifier(seed=42)
+     model.load_model("assets/model/")
+
+     # make predictions
+     predictions = defaultdict(dict)
+
+     # rows of `features` only contain molecules that survived standardization,
+     # so map each original SMILES index to its row in the feature matrix
+     kept_idxs = [i for i in range(len(smiles_list)) if i not in removed_idxs]
+     row_of = {orig_idx: row for row, orig_idx in enumerate(kept_idxs)}
+
+     for i, smiles in enumerate(smiles_list):
+         for target in model.tasks:
+             predictions[smiles][target] = (
+                 0.0
+                 if i in removed_idxs
+                 else float(
+                     model.predict(target, features[row_of[i]].reshape(1, -1))[0]
+                 )
+             )
+
+     return predictions
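
Called end to end, the function might be used like this (assuming the pickled assets and the joblib checkpoints referenced above have been created, e.g. by train.py):

from predict import predict

preds = predict(["CCO", "CC(=O)Oc1ccccc1C(=O)O"])
for smiles, per_task in preds.items():
    print(smiles, per_task["NR-AR"], per_task["SR-p53"])  # per-task toxicity scores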
requirements.txt CHANGED
@@ -1,2 +1,9 @@
  fastapi
  uvicorn[standard]
+ statsmodels
+ rdkit
+ numpy
+ scikit-learn
+ joblib
+ tabulate
+ datasets
train.py ADDED
@@ -0,0 +1,79 @@
+ """
+ Script for fitting and saving the preprocessing assets as well as the fitted
+ RandomForest models.
+ """
+
+ import numpy as np
+
+ from tabulate import tabulate
+ from datasets import load_dataset
+ from sklearn.metrics import roc_auc_score
+
+ from data import preprocess_molecules
+ from model import Tox21RFClassifier
+ from utils import HF_TOKEN
+
+
+ def get_sample_mask(removed_idxs: list[int], labels: np.ndarray):
+     # mask out NaN labels and labels of removed indices
+     task_mask = ~np.isnan(labels)
+     removed_mask = np.ones_like(labels, dtype=bool)
+     removed_mask[removed_idxs] = 0
+
+     # feature_mask indexes the filtered feature matrix (removed molecules are
+     # already dropped there); label_mask indexes the original label array
+     feature_mask = task_mask[removed_mask]
+     label_mask = np.logical_and(task_mask, removed_mask)
+
+     return feature_mask, label_mask
+
+
+ def main():
+     # save preprocessing scaler and ECDF distributions
+     save_folder = "assets/model/"
+     ds = load_dataset("tschouis/tox21", token=HF_TOKEN)
+
+     print("Preprocess train molecules")
+     train_smiles = list(ds["train"]["smiles"])
+     train_features, train_removed_idxs = preprocess_molecules(
+         train_smiles,
+         save_ecdf_path="assets/ecdfs.pkl",
+         save_scaler_path="assets/scaler.pkl",
+     )
+
+     print("Preprocess validation molecules")
+     val_smiles = list(ds["validation"]["smiles"])
+     val_features, val_removed_idxs = preprocess_molecules(
+         val_smiles,
+         load_ecdf_path="assets/ecdfs.pkl",
+         load_scaler_path="assets/scaler.pkl",
+     )
+
+     model = Tox21RFClassifier(seed=42)
+     print("Start training.")
+     for task in model.tasks:
+         task_labels = ds["train"].to_pandas()[task].to_numpy()
+         feature_mask, label_mask = get_sample_mask(train_removed_idxs, task_labels)
+
+         print(f"Fit task {task} using {sum(label_mask)} samples")
+         model.fit(
+             task, train_features[feature_mask], task_labels[label_mask].astype(int)
+         )
+
+     print(f"Save model under {save_folder}")
+     # model.save_model(save_folder)
+
+     print("Evaluate model")
+     results = {}
+     for task in model.tasks:
+         task_labels = ds["validation"].to_pandas()[task].to_numpy()
+         feature_mask, label_mask = get_sample_mask(val_removed_idxs, task_labels)
+
+         pred = model.predict(task, val_features[feature_mask])
+         results[task] = [
+             roc_auc_score(y_true=task_labels[label_mask].astype(int), y_score=pred)
+         ]
+
+     print("Results:")
+     print(tabulate(results, headers="keys"))
+
+
+ if __name__ == "__main__":
+     main()
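
A small worked example of get_sample_mask with toy labels, to make the two masks concrete (molecule 1 is treated as removed during standardization, molecule 2 has a NaN label):

import numpy as np
from train import get_sample_mask

labels = np.array([1.0, 0.0, np.nan, 1.0, 0.0])
feature_mask, label_mask = get_sample_mask(removed_idxs=[1], labels=labels)
# label_mask indexes the original label array:  [True, False, False, True, True]
# feature_mask indexes the filtered feature matrix, where molecule 1 is
# already dropped:                               [True, False, True, True]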
utils.py ADDED
@@ -0,0 +1,441 @@
+ ## These MolStandardizer classes are due to Paolo Tosco.
+ ## They were taken from the FS-Mol GitHub repository
+ ## (https://github.com/microsoft/FS-Mol/blob/main/fs_mol/preprocessing/utils/
+ ## standardizer.py)
+ ## and ensure that a fixed sequence of standardization operations is applied:
+ ## https://gist.github.com/ptosco/7e6b9ab9cc3e44ba0919060beaed198e
+
+ import os
+ import pickle
+
+ from rdkit import Chem
+ from rdkit.Chem.MolStandardize import rdMolStandardize
+
+ HF_TOKEN = os.environ.get("HF_TOKEN")
+
+ TASKS = [
+     "NR-AR",
+     "NR-AR-LBD",
+     "NR-AhR",
+     "NR-Aromatase",
+     "NR-ER",
+     "NR-ER-LBD",
+     "NR-PPAR-gamma",
+     "SR-ARE",
+     "SR-ATAD5",
+     "SR-HSE",
+     "SR-MMP",
+     "SR-p53",
+ ]
+
+ # Indices (into Descriptors._descList) of the 200 RDKit descriptors used as
+ # features: descriptors 0-16 and 25-207.
+ USED_200_DESCR = [
+     0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16,
+     25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39,
+     40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54,
+     55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69,
+     70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84,
+     85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99,
+     100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114,
+     115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, 128, 129,
+     130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144,
+     145, 146, 147, 148, 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159,
+     160, 161, 162, 163, 164, 165, 166, 167, 168, 169, 170, 171, 172, 173, 174,
+     175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189,
+     190, 191, 192, 193, 194, 195, 196, 197, 198, 199, 200, 201, 202, 203, 204,
+     205, 206, 207,
+ ]
+
+
+ class Standardizer:
+     """
+     Simple wrapper class around the RDKit MolStandardize functionality.
+     """
+
+     DEFAULT_CANON_TAUT = False
+     DEFAULT_METAL_DISCONNECT = False
+     MAX_TAUTOMERS = 100
+     MAX_TRANSFORMS = 100
+     MAX_RESTARTS = 200
+     PREFER_ORGANIC = True
+
+     def __init__(
+         self,
+         metal_disconnect=None,
+         canon_taut=None,
+     ):
+         """
+         Constructor.
+         All parameters are optional.
+         :param metal_disconnect: if True, metallorganic complexes are
+                                  disconnected
+         :param canon_taut:       if True, molecules are converted to their
+                                  canonical tautomer
+         """
+         super().__init__()
+         if metal_disconnect is None:
+             metal_disconnect = self.DEFAULT_METAL_DISCONNECT
+         if canon_taut is None:
+             canon_taut = self.DEFAULT_CANON_TAUT
+         self._canon_taut = canon_taut
+         self._metal_disconnect = metal_disconnect
+         self._taut_enumerator = None
+         self._uncharger = None
+         self._lfrag_chooser = None
+         self._metal_disconnector = None
+         self._normalizer = None
+         self._reionizer = None
+         self._params = None
+
+     @property
+     def params(self):
+         """Return the MolStandardize CleanupParameters."""
+         if self._params is None:
+             self._params = rdMolStandardize.CleanupParameters()
+             self._params.maxTautomers = self.MAX_TAUTOMERS
+             self._params.maxTransforms = self.MAX_TRANSFORMS
+             self._params.maxRestarts = self.MAX_RESTARTS
+             self._params.preferOrganic = self.PREFER_ORGANIC
+             self._params.tautomerRemoveSp3Stereo = False
+         return self._params
+
+     @property
+     def canon_taut(self):
+         """Return whether tautomer canonicalization will be done."""
+         return self._canon_taut
+
+     @property
+     def metal_disconnect(self):
+         """Return whether metallorganic complexes will be disconnected."""
+         return self._metal_disconnect
+
+     @property
+     def taut_enumerator(self):
+         """Return the TautomerEnumerator object."""
+         if self._taut_enumerator is None:
+             self._taut_enumerator = rdMolStandardize.TautomerEnumerator(self.params)
+         return self._taut_enumerator
+
+     @property
+     def uncharger(self):
+         """Return the Uncharger object."""
+         if self._uncharger is None:
+             self._uncharger = rdMolStandardize.Uncharger()
+         return self._uncharger
+
+     @property
+     def lfrag_chooser(self):
+         """Return the LargestFragmentChooser object."""
+         if self._lfrag_chooser is None:
+             self._lfrag_chooser = rdMolStandardize.LargestFragmentChooser(
+                 self.params.preferOrganic
+             )
+         return self._lfrag_chooser
+
+     @property
+     def metal_disconnector(self):
+         """Return the MetalDisconnector object."""
+         if self._metal_disconnector is None:
+             self._metal_disconnector = rdMolStandardize.MetalDisconnector()
+         return self._metal_disconnector
+
+     @property
+     def normalizer(self):
+         """Return the Normalizer object."""
+         if self._normalizer is None:
+             self._normalizer = rdMolStandardize.Normalizer(
+                 self.params.normalizationsFile, self.params.maxRestarts
+             )
+         return self._normalizer
+
+     @property
+     def reionizer(self):
+         """Return the Reionizer object."""
+         if self._reionizer is None:
+             self._reionizer = rdMolStandardize.Reionizer(self.params.acidbaseFile)
+         return self._reionizer
+
+     def charge_parent(self, mol_in):
+         """Sequentially apply a series of MolStandardize operations:
+         * MetalDisconnector
+         * Normalizer
+         * Reionizer
+         * LargestFragmentChooser
+         * Uncharger
+         The net result is that a desalted, normalized, neutral
+         molecule with implicit Hs is returned.
+         """
+         params = Chem.RemoveHsParameters()
+         params.removeAndTrackIsotopes = True
+         mol_in = Chem.RemoveHs(mol_in, params, sanitize=False)
+         if self._metal_disconnect:
+             mol_in = self.metal_disconnector.Disconnect(mol_in)
+         normalized = self.normalizer.normalize(mol_in)
+         Chem.SanitizeMol(normalized)
+         normalized = self.reionizer.reionize(normalized)
+         Chem.AssignStereochemistry(normalized)
+         normalized = self.lfrag_chooser.choose(normalized)
+         normalized = self.uncharger.uncharge(normalized)
+         # need this to reassess aromaticity on things like
+         # cyclopentadienyl, tropylium, azolium, etc.
+         Chem.SanitizeMol(normalized)
+         return Chem.RemoveHs(Chem.AddHs(normalized))
+
+     def standardize_mol(self, mol_in):
+         """
+         Standardize a single molecule.
+         :param mol_in: a Chem.Mol
+         :return: * (standardized Chem.Mol, n_taut) tuple
+                    if success. n_taut will be negative if
+                    tautomer enumeration was aborted due
+                    to reaching a limit
+                  * (None, error_msg) if failure
+         This calls self.charge_parent() and, if self._canon_taut
+         is True, runs tautomer canonicalization.
+         """
+         n_tautomers = 0
+         if isinstance(mol_in, Chem.Mol):
+             name = None
+             try:
+                 name = mol_in.GetProp("_Name")
+             except KeyError:
+                 pass
+             if not name:
+                 name = "NONAME"
+         else:
+             error = f"Expected SMILES or Chem.Mol as input, got {str(type(mol_in))}"
+             return None, error
+         try:
+             mol_out = self.charge_parent(mol_in)
+         except Exception as e:
+             error = f"charge_parent FAILED: {str(e).strip()}"
+             return None, error
+         if self._canon_taut:
+             try:
+                 res = self.taut_enumerator.Enumerate(mol_out, False)
+             except TypeError:
+                 # we are still on the pre-2021 RDKit API
+                 res = self.taut_enumerator.Enumerate(mol_out)
+             except Exception as e:
+                 # something else went wrong
+                 error = f"canon_taut FAILED: {str(e).strip()}"
+                 return None, error
+             n_tautomers = len(res)
+             if hasattr(res, "status"):
+                 completed = (
+                     res.status == rdMolStandardize.TautomerEnumeratorStatus.Completed
+                 )
+             else:
+                 # we are still on the pre-2021 RDKit API
+                 completed = len(res) < 1000
+             if not completed:
+                 n_tautomers = -n_tautomers
+             try:
+                 mol_out = self.taut_enumerator.PickCanonical(res)
+             except AttributeError:
+                 # we are still on the pre-2021 RDKit API
+                 mol_out = max(
+                     [(self.taut_enumerator.ScoreTautomer(m), m) for m in res]
+                 )[1]
+             except Exception as e:
+                 # something else went wrong
+                 error = f"canon_taut FAILED: {str(e).strip()}"
+                 return None, error
+         mol_out.SetProp("_Name", name)
+         return mol_out, n_tautomers
+
+
+ def load_pickle(path: str):
+     with open(path, "rb") as file:
+         content = pickle.load(file)
+     return content
+
+
+ def write_pickle(path: str, obj: object):
+     with open(path, "wb") as file:
+         pickle.dump(obj, file)
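
For reference, the Standardizer can also be exercised on its own; a minimal sketch (the exact canonical SMILES printed depends on the installed RDKit version):

from rdkit import Chem
from utils import Standardizer

sm = Standardizer(canon_taut=True)
mol = Chem.MolFromSmiles("CC(=O)[O-].[Na+]")   # sodium acetate (salt form)
std_mol, n_taut = sm.standardize_mol(mol)
if std_mol is not None:
    print(Chem.MolToSmiles(std_mol), n_taut)   # desalted, neutralized parent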