antoniaebner committed on
Commit
593848b
·
1 Parent(s): 6eb59be

update pipeline

Browse files
Files changed (5) hide show
  1. predict.py +25 -17
  2. src/data.py +74 -170
  3. src/model.py +17 -7
  4. src/preprocess.py +405 -0
  5. src/utils.py +2 -0
predict.py CHANGED
@@ -8,13 +8,14 @@ SMILES and target names as keys.
8
  # Dependencies
9
  from collections import defaultdict
10
 
11
- from src.data import preprocess_molecules
12
- from src.model import Tox21RFClassifier
 
13
 
14
  # ---------------------------------------------------------------------------------------
15
 
16
 
17
- def predict(smiles_list: list[str]) -> dict:
18
  """Applies the classifier to a list of SMILES strings. Returns prediction=0.0 for
19
  any molecule that could not be cleaned.
20
 
@@ -26,29 +27,36 @@ def predict(smiles_list: list[str]) -> dict:
26
  """
27
  print(f"Received {len(smiles_list)} SMILES strings")
28
  # preprocessing pipeline
29
- features, removed_idxs = preprocess_molecules(
 
 
 
 
 
 
 
 
30
  smiles_list,
31
- load_ecdf_path="assets/ecdfs.pkl",
32
- load_scaler_path="assets/scaler.pkl",
 
33
  )
34
- print(f"{len(removed_idxs)} molecules removed during cleaning")
 
35
 
36
  # setup model
37
  model = Tox21RFClassifier(seed=42)
38
- model.load_model("assets/rf_alltasks.joblib")
 
 
39
 
40
  # make predictions
41
  predictions = defaultdict(dict)
42
- # make smiles list with same num_samples as features
43
- clean_smiles = [smi for i, smi in enumerate(smiles_list) if i not in removed_idxs]
44
- no_pred_smiles = [smi for i, smi in enumerate(smiles_list) if i in removed_idxs]
45
 
46
  for target in model.tasks:
47
  target_pred = model.predict(target, features)
48
- for i, smiles in enumerate(clean_smiles):
49
- predictions[smiles][target] = target_pred[i]
50
-
51
- for smiles in no_pred_smiles:
52
- predictions[smiles][target] = 0.0
53
-
54
  return predictions
 
8
  # Dependencies
9
  from collections import defaultdict
10
 
11
+ from .data import create_descriptors
12
+ from .utils import load_pickle, KNOWN_DESCR
13
+ from .model import Tox21RFClassifier
14
 
15
  # ---------------------------------------------------------------------------------------
16
 
17
 
18
+ def predict(smiles_list: list[str]) -> dict[str, dict[str, float]]:
19
  """Applies the classifier to a list of SMILES strings. Returns prediction=0.0 for
20
  any molecule that could not be cleaned.
21
 
 
27
  """
28
  print(f"Received {len(smiles_list)} SMILES strings")
29
  # preprocessing pipeline
30
+ ecdfs_path = "assets/ecdfs.pkl"
31
+ scaler_path = "assets/scaler.pkl"
32
+ ecdfs = load_pickle(ecdfs_path)
33
+ scaler = load_pickle(scaler_path)
34
+ print(f"Loaded ecdfs from {ecdfs_path}")
35
+ print(f"Loaded scaler from {scaler_path}")
36
+
37
+ descriptors = KNOWN_DESCR
38
+ features, mol_mask = create_descriptors(
39
  smiles_list,
40
+ ecdfs=ecdfs,
41
+ scaler=scaler,
42
+ descriptors=descriptors,
43
  )
44
+ print(f"Created descriptors {descriptors} for molecules.")
45
+ print(f"{len(mol_mask) - sum(mol_mask)} molecules removed during cleaning")
46
 
47
  # setup model
48
  model = Tox21RFClassifier(seed=42)
49
+ model_path = "assets/rf_alltasks.joblib"
50
+ model.load_model(model_path)
51
+ print(f"Loaded model from {model_path}")
52
 
53
  # make predictions
54
  predictions = defaultdict(dict)
55
+ # create a list with same length as smiles_list to obtain indices for respective features
56
+ feat_indices = np.cumsum(mol_mask) - 1
 
57
 
58
  for target in model.tasks:
59
  target_pred = model.predict(target, features)
60
+ for smiles, is_clean, i in zip(smiles_list, mol_mask, feat_indices):
61
+ predictions[smiles][target] = float(target_pred[i]) if is_clean else 0.0
 
 
 
 
62
  return predictions
src/data.py CHANGED
@@ -7,8 +7,10 @@ SMILES and target names as keys.
7
  """
8
 
9
  import os
 
10
 
11
  import numpy as np
 
12
 
13
  from sklearn.preprocessing import StandardScaler
14
  from statsmodels.distributions.empirical_distribution import ECDF
@@ -17,177 +19,79 @@ from rdkit import Chem, DataStructs
17
  from rdkit.Chem import Descriptors, rdFingerprintGenerator
18
  from rdkit.Chem.rdchem import Mol
19
 
20
- from .utils import USED_200_DESCR, Standardizer, load_pickle, write_pickle
21
-
22
-
23
- def preprocess_molecules(
24
- smiles_list: list[str],
25
- load_ecdf_path: str = "",
26
- load_scaler_path: str = "",
27
- save_ecdf_path: str = "",
28
- save_scaler_path: str = "",
29
- ) -> tuple[np.ndarray, list[int]]:
30
- """Preprocessing pipeline for a list of molecules.
31
-
32
- Args:
33
- smiles_list (list[str]): list of SMILES
34
- load_ecdf_path (str, optional): Path to load ECDFs from. Defaults to "".
35
- load_scaler_path (str, optional): Path to load fitted StandardScaler from. Defaults to "".
36
- save_ecdf_path (str, optional): Path to save calculated ECDFs. Defaults to "".
37
- save_scaler_path (str, optional): Path to save fitted StandardScaler. Defaults to "".
38
-
39
- Returns:
40
- np.ndarray: normalized ECFPs fingerprints and RDKit descriptor quantiles
41
- list[int]: list of removed indices of molecules that could not be cleaned
42
- """
43
-
44
- assert not (
45
- load_ecdf_path and save_ecdf_path
46
- ), "Cannot pass 'load_ecdf_path' and 'save_ecdf_path' simultaneously"
47
- assert not (
48
- load_scaler_path and save_scaler_path
49
- ), "Cannot pass 'load_scaler_path' and 'save_scaler_path' simultaneously"
50
-
51
- ecdfs = (
52
- load_pickle(load_ecdf_path)
53
- if load_ecdf_path and os.path.exists(load_ecdf_path)
54
- else None
 
 
 
 
 
 
 
 
 
 
 
 
55
  )
56
- scaler = (
57
- load_pickle(load_scaler_path)
58
- if load_scaler_path and os.path.exists(load_scaler_path)
59
- else None
60
- )
61
-
62
- # Create cleanded rdkit mol objects
63
- mols, removed_idxs = create_cleaned_mol_objects(smiles_list)
64
- print("Cleaned molecules")
65
-
66
- # Create fingerprints and descriptors
67
- ecfps = create_ecfp_fps(mols)
68
- print("Created ECFP fingerprints")
69
- rdkit_descrs = create_rdkit_descriptors(mols)
70
- print("Created RDKit descriptors")
71
-
72
- # Create and save ecdfs
73
- if ecdfs is None:
74
- print("Create ECDFs")
75
- ecdfs = []
76
- for column in range(rdkit_descrs.shape[1]):
77
- raw_values = rdkit_descrs[:, column].reshape(-1)
78
- ecdfs.append(ECDF(raw_values))
79
- if save_ecdf_path:
80
- write_pickle(save_ecdf_path, ecdfs)
81
- print(f"Saved ECDFs under {save_ecdf_path}")
82
-
83
- # Create quantiles
84
- rdkit_descr_quantiles = create_quantiles(rdkit_descrs, ecdfs)
85
- print("Created quantiles of RDKit descriptors")
86
-
87
- # Concatenate features
88
- raw_features = np.concatenate((ecfps, rdkit_descr_quantiles), axis=1)
89
-
90
- if scaler is None:
91
- scaler = StandardScaler()
92
- scaler.fit(raw_features)
93
- print("Fitted the StandardScaler")
94
- if save_scaler_path:
95
- write_pickle(save_scaler_path, scaler)
96
- print(f"Saved the StandardScaler under {save_scaler_path}")
97
-
98
- # Normalize feature vectors
99
- normalized_features = scaler.transform(raw_features)
100
- print("Normalized the molecule features")
101
-
102
- return normalized_features, removed_idxs
103
-
104
-
105
- def create_cleaned_mol_objects(smiles: list[str]) -> list[Mol]:
106
- """This function creates cleaned RDKit mol objects from a list of SMILES.
107
-
108
- Args:
109
- smiles (list[str]): list of SMILES
110
-
111
- Returns:
112
- list[Mol]: list of cleaned molecules
113
- """
114
- sm = Standardizer(canon_taut=True)
115
-
116
- removed_idxs = list()
117
- mols = list()
118
- for i, smile in enumerate(smiles):
119
- mol = Chem.MolFromSmiles(smile)
120
- standardized_mol, _ = sm.standardize_mol(mol)
121
- if standardized_mol is None:
122
- removed_idxs.append(i)
123
- continue
124
- can_mol = Chem.MolFromSmiles(Chem.MolToSmiles(standardized_mol))
125
- mols.append(can_mol)
126
- return mols, removed_idxs
127
-
128
-
129
- def create_ecfp_fps(mols: list[Mol]) -> np.ndarray:
130
- """This function ECFP fingerprints for a list of molecules.
131
-
132
- Args:
133
- mols (list[Mol]): list of molecules
134
 
135
- Returns:
136
- np.ndarray: ECFP fingerprints of molecules
137
- """
138
- ecfps = list()
139
-
140
- for mol in mols:
141
- fp_sparse_vec = rdFingerprintGenerator.GetCountFPs(
142
- [mol], fpType=rdFingerprintGenerator.MorganFP
143
- )[0]
144
- fp = np.zeros((0,), np.int8)
145
- DataStructs.ConvertToNumpyArray(fp_sparse_vec, fp)
146
-
147
- ecfps.append(fp)
148
-
149
- return np.array(ecfps)
150
-
151
-
152
- def create_rdkit_descriptors(mols: list[Mol]) -> np.ndarray:
153
- """This function creates RDKit descriptors for a list of molecules.
154
-
155
- Args:
156
- mols (list[Mol]): list of molecules
157
-
158
- Returns:
159
- np.ndarray: RDKit descriptors of molecules
160
- """
161
- rdkit_descriptors = list()
162
-
163
- for mol in mols:
164
- descrs = []
165
- for _, descr_calc_fn in Descriptors._descList:
166
- descrs.append(descr_calc_fn(mol))
167
-
168
- descrs = np.array(descrs)
169
- descrs = descrs[USED_200_DESCR]
170
- rdkit_descriptors.append(descrs)
171
-
172
- return np.array(rdkit_descriptors)
173
-
174
-
175
- def create_quantiles(raw_features: np.ndarray, ecdfs: list) -> np.ndarray:
176
- """Create quantile values for given features using the columns
177
-
178
- Args:
179
- raw_features (np.ndarray): values to put into quantiles
180
- ecdfs (list): ECDFs to use
181
-
182
- Returns:
183
- np.ndarray: computed quantiles
184
- """
185
- quantiles = np.zeros_like(raw_features)
186
 
187
- for column in range(raw_features.shape[1]):
188
- raw_values = raw_features[:, column].reshape(-1)
189
- ecdf = ecdfs[column]
190
- q = ecdf(raw_values)
191
- quantiles[:, column] = q
192
 
193
- return quantiles
 
 
 
 
7
  """
8
 
9
  import os
10
+ from typing import Iterable, Literal
11
 
12
  import numpy as np
13
+ import torch
14
 
15
  from sklearn.preprocessing import StandardScaler
16
  from statsmodels.distributions.empirical_distribution import ECDF
 
19
  from rdkit.Chem import Descriptors, rdFingerprintGenerator
20
  from rdkit.Chem.rdchem import Mol
21
 
22
+ from .utils import USED_200_DESCR, Standardizer, load_pickle, write_pickle, KNOWN_DESCR
23
+ from .preprocess import normalize_features
24
+
25
+
26
+ def get_descriptor_dataset(
27
+ data_path: str,
28
+ descriptors: Iterable[str] | Literal["all"],
29
+ scaler=None,
30
+ save_scaler_path: str = "data/scaler.pkl",
31
+ verbose=True,
32
+ normalize=True,
33
+ ):
34
+ if descriptors == "all":
35
+ descriptors = KNOWN_DESCR
36
+
37
+ assert isinstance(descriptors, Iterable), "Passed descriptors are not iterable!"
38
+ assert all(
39
+ [descr in KNOWN_DESCR for descr in descriptors]
40
+ ), f"Passed descriptors contains unknown descriptor types. Allowed descriptors: {KNOWN_DESCR}"
41
+
42
+ datafile = np.load(data_path)
43
+
44
+ if not isinstance(datafile, np.ndarray):
45
+ # concatenate all descriptors and normalize
46
+ data = np.concatenate([datafile[descr] for descr in descriptors], axis=1)
47
+ labels = datafile["labels"]
48
+
49
+ else:
50
+ print("NPY file passed, cannot select specific descriptors")
51
+ data, labels = datafile[:, :-12], datafile[:, -12:]
52
+
53
+ if normalize:
54
+ data, scaler = normalize_features(
55
+ data,
56
+ scaler=scaler,
57
+ save_scaler_path=save_scaler_path,
58
+ verbose=verbose,
59
+ )
60
+
61
+ # filter out unsanitized molecules
62
+ mask = ~np.isnan(data).any(axis=1)
63
+ data = data[mask]
64
+ labels = labels[mask]
65
+
66
+ assert data.shape[0] == labels.shape[0], (
67
+ f"Mismatch between data and labels: "
68
+ f"data has {data.shape[0]} samples, but labels has {labels.shape[0]} samples."
69
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
70
 
71
+ return (data, labels, scaler)
72
+
73
+
74
+ def get_torch_descriptor_dataset(
75
+ data_path: str,
76
+ descriptors: list[str],
77
+ scaler=None,
78
+ save_scaler_path: str = "data/scaler.pkl",
79
+ nan_to_num: int = -100,
80
+ verbose=True,
81
+ normalize=True,
82
+ ) -> torch.utils.data.TensorDataset:
83
+ data, labels, scaler = get_descriptor_dataset(
84
+ data_path,
85
+ descriptors,
86
+ scaler,
87
+ save_scaler_path,
88
+ verbose=verbose,
89
+ normalize=normalize,
90
+ )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
91
 
92
+ labels = np.nan_to_num(labels, nan=nan_to_num)
 
 
 
 
93
 
94
+ dataset = torch.utils.data.TensorDataset(
95
+ torch.FloatTensor(data), torch.LongTensor(labels)
96
+ )
97
+ return dataset, scaler
src/model.py CHANGED
@@ -19,17 +19,27 @@ from .utils import TASKS
19
  class Tox21RFClassifier:
20
  """A random forest classifier that assigns a toxicity score to a given SMILES string."""
21
 
22
- def __init__(self, seed: int = 42):
23
  """Initialize a random forest classifier for each of the 12 Tox21 tasks.
24
 
25
  Args:
26
  seed (int, optional): seed for RF to ensure reproducibility. Defaults to 42.
27
  """
28
  self.tasks = TASKS
29
- self.model = {
30
- task: RandomForestClassifier(n_estimators=1001, random_state=seed)
31
- for task in self.tasks
32
- }
 
 
 
 
 
 
 
 
 
 
33
 
34
  def load_model(self, path: str) -> None:
35
  """Loads the model from a given path
@@ -45,8 +55,8 @@ class Tox21RFClassifier:
45
  Args:
46
  path (str): path to save model to
47
  """
48
- if not os.path.exists(os.path.pardir(path)):
49
- os.makedirs(os.path.pardir(path))
50
 
51
  joblib.dump(self.model, path)
52
 
 
19
  class Tox21RFClassifier:
20
  """A random forest classifier that assigns a toxicity score to a given SMILES string."""
21
 
22
+ def __init__(self, seed: int = 42, task_config: dict = None):
23
  """Initialize a random forest classifier for each of the 12 Tox21 tasks.
24
 
25
  Args:
26
  seed (int, optional): seed for RF to ensure reproducibility. Defaults to 42.
27
  """
28
  self.tasks = TASKS
29
+ if task_config is None:
30
+ self.model = {
31
+ task: RandomForestClassifier(
32
+ n_estimators=1000, random_state=seed, n_jobs=8
33
+ )
34
+ for task in self.tasks
35
+ }
36
+ else:
37
+ self.model = {
38
+ task: RandomForestClassifier(
39
+ **task_config[task], random_state=seed, n_jobs=8
40
+ )
41
+ for task in self.tasks
42
+ }
43
 
44
  def load_model(self, path: str) -> None:
45
  """Loads the model from a given path
 
55
  Args:
56
  path (str): path to save model to
57
  """
58
+ if not os.path.exists(os.path.dirname(path)):
59
+ os.makedirs(os.path.dirname(path))
60
 
61
  joblib.dump(self.model, path)
62
 
src/preprocess.py ADDED
@@ -0,0 +1,405 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # pipeline taken from https://huggingface.co/spaces/ml-jku/mhnfs/blob/main/src/data_preprocessing/create_descriptors.py
2
+
3
+ """
4
+ This file includes the data processing for Tox21.
5
+ As an input it takes a list of SMILES and it outputs a nested dictionary with
6
+ SMILES and target names as keys.
7
+ """
8
+
9
+ import os
10
+ import argparse
11
+ import json
12
+ from typing import Iterable
13
+
14
+ import numpy as np
15
+ import pandas as pd
16
+
17
+ from sklearn.preprocessing import StandardScaler
18
+ from statsmodels.distributions.empirical_distribution import ECDF
19
+ from datasets import load_dataset
20
+
21
+ from rdkit import Chem, DataStructs
22
+ from rdkit.Chem import Descriptors, rdFingerprintGenerator, MACCSkeys
23
+ from rdkit.Chem.rdchem import Mol
24
+
25
+ from .utils import (
26
+ TASKS,
27
+ KNOWN_DESCR,
28
+ HF_TOKEN,
29
+ USED_200_DESCR,
30
+ Standardizer,
31
+ load_pickle,
32
+ write_pickle,
33
+ )
34
+
35
+ parser = argparse.ArgumentParser(
36
+ description="Data preprocessing script for the Tox21 dataset"
37
+ )
38
+
39
+ parser.add_argument(
40
+ "--save_folder",
41
+ type=str,
42
+ default="data/",
43
+ )
44
+
45
+ parser.add_argument(
46
+ "--use_hf",
47
+ type=int,
48
+ default=0,
49
+ )
50
+
51
+ parser.add_argument(
52
+ "--path_ecdfs",
53
+ type=str,
54
+ default="data/ecdfs.pkl",
55
+ )
56
+
57
+ parser.add_argument(
58
+ "--tox_smarts_filepath",
59
+ type=str,
60
+ default="data/tox_smarts.json",
61
+ )
62
+
63
+
64
+ def create_cleaned_mol_objects(smiles: list[str]) -> tuple[list[Mol], np.ndarray]:
65
+ """This function creates cleaned RDKit mol objects from a list of SMILES.
66
+
67
+ Args:
68
+ smiles (list[str]): list of SMILES
69
+
70
+ Returns:
71
+ list[Mol]: list of cleaned molecules
72
+ np.ndarray[bool]: mask that contains False at index `i`, if molecule in `smiles` at
73
+ index `i` could not be cleaned and was removed.
74
+ """
75
+ sm = Standardizer(canon_taut=True)
76
+
77
+ clean_mol_mask = list()
78
+ mols = list()
79
+ for i, smile in enumerate(smiles):
80
+ mol = Chem.MolFromSmiles(smile)
81
+ standardized_mol, _ = sm.standardize_mol(mol)
82
+ is_cleaned = standardized_mol is not None
83
+ clean_mol_mask.append(is_cleaned)
84
+ if not is_cleaned:
85
+ continue
86
+ can_mol = Chem.MolFromSmiles(Chem.MolToSmiles(standardized_mol))
87
+ mols.append(can_mol)
88
+
89
+ return mols, np.array(clean_mol_mask)
90
+
91
+
92
+ def create_ecfp_fps(mols: list[Mol]) -> np.ndarray:
93
+ """This function creates ECFP fingerprints for a list of molecules.
94
+
95
+ Args:
96
+ mols (list[Mol]): list of molecules
97
+
98
+ Returns:
99
+ np.ndarray: ECFP fingerprints of molecules
100
+ """
101
+ ecfps = list()
102
+
103
+ for mol in mols:
104
+ fp_sparse_vec = rdFingerprintGenerator.GetCountFPs(
105
+ [mol], fpType=rdFingerprintGenerator.MorganFP
106
+ )[0]
107
+ fp = np.zeros((0,), np.int8)
108
+ DataStructs.ConvertToNumpyArray(fp_sparse_vec, fp)
109
+
110
+ ecfps.append(fp)
111
+
112
+ return np.array(ecfps)
113
+
114
+
115
+ def create_maccs_keys(mols: list[Mol]) -> np.ndarray:
116
+ maccs = [MACCSkeys.GenMACCSKeys(x) for x in mols]
117
+ return np.array(maccs)
118
+
119
+
120
+ def get_tox_patterns(filepath: str):
121
+ """This calculates tox features defined in tox_smarts.json.
122
+ Args:
123
+ filepath: Path to the JSON file containing the tox SMARTS definitions
124
+ Returns: a list of (patterns, negations, merge_any) tuples, one per SMARTS entry
125
+ """
126
+ # load patterns
127
+ with open(filepath) as f:
128
+ smarts_list = [s[1] for s in json.load(f)]
129
+
130
+ # Code does not work for this case
131
+ assert len([s for s in smarts_list if ("AND" in s) and ("OR" in s)]) == 0
132
+
133
+ # Chem.MolFromSmarts takes a long time so it pays off to parse all the smarts first
134
+ # and then use them for all molecules. This gives a huge speedup over existing code.
135
+ # a list of patterns, whether to negate the match result and how to join them to obtain one boolean value
136
+ all_patterns = []
137
+ for smarts in smarts_list:
138
+ patterns = [] # list of smarts-patterns
139
+ # value for each of the patterns above. Negates the values of the above later.
140
+ negations = []
141
+
142
+ if " AND " in smarts:
143
+ smarts = smarts.split(" AND ")
144
+ merge_any = False # If an ' AND ' is found all 'subsmarts' have to match
145
+ else:
146
+ # If there is an ' OR ' present it's enough if any of the 'subsmarts' match.
147
+ # This also accumulates smarts where neither ' OR ' nor ' AND ' occur
148
+ smarts = smarts.split(" OR ")
149
+ merge_any = True
150
+
151
+ # for all subsmarts check if they are preceded by 'NOT '
152
+ for s in smarts:
153
+ neg = s.startswith("NOT ")
154
+ if neg:
155
+ s = s[4:]
156
+ patterns.append(Chem.MolFromSmarts(s))
157
+ negations.append(neg)
158
+
159
+ all_patterns.append((patterns, negations, merge_any))
160
+ return all_patterns
161
+
162
+
163
+ def create_tox_features(mols: list[Mol], patterns: list) -> np.ndarray:
164
+ """Matches the tox patterns against a molecule. Returns a boolean array"""
165
+ tox_data = []
166
+ for mol in mols:
167
+ mol_features = []
168
+ for patts, negations, merge_any in patterns:
169
+ matches = [mol.HasSubstructMatch(p) for p in patts]
170
+ matches = [m != n for m, n in zip(matches, negations)]
171
+ if merge_any:
172
+ pres = any(matches)
173
+ else:
174
+ pres = all(matches)
175
+ mol_features.append(pres)
176
+
177
+ tox_data.append(np.array(mol_features))
178
+
179
+ return np.array(tox_data)
180
+
181
+
182
+ def create_rdkit_descriptors(mols: list[Mol]) -> np.ndarray:
183
+ """This function creates RDKit descriptors for a list of molecules.
184
+
185
+ Args:
186
+ mols (list[Mol]): list of molecules
187
+
188
+ Returns:
189
+ np.ndarray: RDKit descriptors of molecules
190
+ """
191
+ rdkit_descriptors = list()
192
+
193
+ for mol in mols:
194
+ descrs = []
195
+ for _, descr_calc_fn in Descriptors._descList:
196
+ descrs.append(descr_calc_fn(mol))
197
+
198
+ descrs = np.array(descrs)
199
+ descrs = descrs[USED_200_DESCR]
200
+ rdkit_descriptors.append(descrs)
201
+
202
+ return np.array(rdkit_descriptors)
203
+
204
+
205
+ def create_quantiles(raw_features: np.ndarray, ecdfs: list) -> np.ndarray:
206
+ """Create quantile values for given features using the columns
207
+
208
+ Args:
209
+ raw_features (np.ndarray): values to put into quantiles
210
+ ecdfs (list): ECDFs to use
211
+
212
+ Returns:
213
+ np.ndarray: computed quantiles
214
+ """
215
+ quantiles = np.zeros_like(raw_features)
216
+
217
+ for column in range(raw_features.shape[1]):
218
+ raw_values = raw_features[:, column].reshape(-1)
219
+ ecdf = ecdfs[column]
220
+ q = ecdf(raw_values)
221
+ quantiles[:, column] = q
222
+
223
+ return quantiles
224
+
225
+
226
+ def fill(features, mask, value=np.nan):
227
+ n_mols = len(mask)
228
+ n_features = features.shape[1]
229
+
230
+ data = np.zeros(shape=(n_mols, n_features))
231
+ data.fill(value)
232
+ data[~mask] = features
233
+ return data
234
+
235
+
236
+ def normalize_features(
237
+ raw_features,
238
+ scaler=None,
239
+ save_scaler_path: str = "",
240
+ verbose=True,
241
+ ):
242
+ if scaler is None:
243
+ scaler = StandardScaler()
244
+ scaler.fit(raw_features)
245
+ if verbose:
246
+ print("Fitted the StandardScaler")
247
+ if save_scaler_path:
248
+ write_pickle(save_scaler_path, scaler)
249
+ if verbose:
250
+ print(f"Saved the StandardScaler under {save_scaler_path}")
251
+
252
+ # Normalize feature vectors
253
+ normalized_features = scaler.transform(raw_features)
254
+ if verbose:
255
+ print("Normalized molecule features")
256
+ return normalized_features, scaler
257
+
258
+
259
+ def create_descriptors(
260
+ smiles,
261
+ ecdfs=None,
262
+ scaler=None,
263
+ descriptors: Iterable = KNOWN_DESCR,
264
+ ):
265
+ # Create cleaned rdkit mol objects
266
+ mols, clean_mol_mask = create_cleaned_mol_objects(smiles)
267
+ print("Cleaned molecules")
268
+
269
+ features = []
270
+ if "ecfps" in descriptors:
271
+ # Create fingerprints and descriptors
272
+ ecfps = create_ecfp_fps(mols)
273
+ # expand using mol_mask
274
+ ecfps = fill(ecfps, ~clean_mol_mask)
275
+ features.append(ecfps)
276
+ print("Created ECFP fingerprints")
277
+
278
+ if "rdkit_descr_quantiles" in descriptors:
279
+ rdkit_descrs = create_rdkit_descriptors(mols)
280
+ print("Created RDKit descriptors")
281
+
282
+ # Create and save ecdfs
283
+ if ecdfs is None:
284
+ print("Create ECDFs")
285
+ ecdfs = []
286
+ for column in range(rdkit_descrs.shape[1]):
287
+ raw_values = rdkit_descrs[:, column].reshape(-1)
288
+ ecdfs.append(ECDF(raw_values))
289
+
290
+ # Create quantiles
291
+ rdkit_descr_quantiles = create_quantiles(rdkit_descrs, ecdfs)
292
+ # expand using mol_mask
293
+ rdkit_descr_quantiles = fill(rdkit_descr_quantiles, ~clean_mol_mask)
294
+ features.append(rdkit_descr_quantiles)
295
+ print("Created quantiles of RDKit descriptors")
296
+
297
+ if "maccs" in descriptors:
298
+ maccs = create_maccs_keys(mols)
299
+ maccs = fill(maccs, ~clean_mol_mask)
300
+ features.append(maccs)
301
+ print("Created MACCS keys")
302
+
303
+ if "tox" in descriptors:
304
+ tox_patterns = get_tox_patterns("assets/tox_smarts.json")
305
+ tox = create_tox_features(mols, tox_patterns)
306
+ tox = fill(tox, ~clean_mol_mask)
307
+ features.append(tox)
308
+ print("Created Tox features")
309
+
310
+ # concatenate features
311
+ raw_features = np.concatenate(features, axis=1)
312
+
313
+ # normalize with scaler if scaler is passed, else create scaler
314
+ features, _ = normalize_features(
315
+ raw_features,
316
+ scaler=scaler,
317
+ verbose=True,
318
+ )
319
+
320
+ return features, clean_mol_mask
321
+
322
+
323
+ def main(args):
324
+ splits = ["train", "validation"]
325
+ ds = load_dataset("tschouis/tox21", token=HF_TOKEN)
326
+
327
+ for split in splits:
328
+
329
+ print(f"Preprocess {split} molecules")
330
+ smiles = list(ds[split]["smiles"])
331
+
332
+ # Create cleaned rdkit mol objects
333
+ mols, clean_mol_mask = create_cleaned_mol_objects(smiles)
334
+ print("Cleaned molecules")
335
+
336
+ tox_patterns = get_tox_patterns(args.tox_smarts_filepath)
337
+
338
+ # Create fingerprints and descriptors
339
+ ecfps = create_ecfp_fps(mols)
340
+ # expand using mol_mask
341
+ ecfps = fill(ecfps, ~clean_mol_mask)
342
+ print("Created ECFP fingerprints")
343
+
344
+ rdkit_descrs = create_rdkit_descriptors(mols)
345
+ print("Created RDKit descriptors")
346
+
347
+ # Create and save ecdfs
348
+ if split == "train":
349
+ print("Create ECDFs")
350
+ ecdfs = []
351
+ for column in range(rdkit_descrs.shape[1]):
352
+ raw_values = rdkit_descrs[:, column].reshape(-1)
353
+ ecdfs.append(ECDF(raw_values))
354
+
355
+ write_pickle(args.path_ecdfs, ecdfs)
356
+ print(f"Saved ECDFs under {args.path_ecdfs}")
357
+ else:
358
+ print(f"Load ECDFs from {args.path_ecdfs}")
359
+ ecdfs = load_pickle(args.path_ecdfs)
360
+
361
+ # Create quantiles
362
+ rdkit_descr_quantiles = create_quantiles(rdkit_descrs, ecdfs)
363
+ # expand using mol_mask
364
+ rdkit_descr_quantiles = fill(rdkit_descr_quantiles, ~clean_mol_mask)
365
+ print("Created quantiles of RDKit descriptors")
366
+
367
+ maccs = create_maccs_keys(mols)
368
+ maccs = fill(maccs, ~clean_mol_mask)
369
+ print("Created MACCS keys")
370
+
371
+ tox = create_tox_features(mols, tox_patterns)
372
+ tox = fill(tox, ~clean_mol_mask)
373
+ print("Created Tox features")
374
+
375
+ labels = []
376
+ for task in TASKS:
377
+ datasplit = ds[split].to_pandas() if args.use_hf else ds[split]
378
+ labels.append(datasplit[task].to_numpy())
379
+ labels = np.stack(labels, axis=1)
380
+
381
+ save_path = os.path.join(args.save_folder, f"tox21_{split}.npz")
382
+ with open(save_path, "wb") as f:
383
+ np.savez(
384
+ f,
385
+ labels=labels,
386
+ ecfps=ecfps,
387
+ rdkit_descr_quantiles=rdkit_descr_quantiles,
388
+ maccs=maccs,
389
+ tox=tox,
390
+ )
391
+ print(f"Saved preprocessed {split} split under {save_path}")
392
+
393
+ print("Preprocessing finished successfully")
394
+
395
+
396
+ if __name__ == "__main__":
397
+ args = parser.parse_args()
398
+
399
+ if not os.path.exists(args.save_folder):
400
+ os.makedirs(args.save_folder)
401
+
402
+ if not os.path.exists(os.path.dirname(args.path_ecdfs)):
403
+ os.makedirs(os.path.dirname(args.path_ecdfs))
404
+
405
+ main(args)
src/utils.py CHANGED
@@ -28,6 +28,8 @@ TASKS = [
28
  "SR-p53",
29
  ]
30
 
 
 
31
  USED_200_DESCR = [
32
  0,
33
  1,
 
28
  "SR-p53",
29
  ]
30
 
31
+ KNOWN_DESCR = ["ecfps", "rdkit_descr_quantiles", "maccs", "tox"]
32
+
33
  USED_200_DESCR = [
34
  0,
35
  1,