antoniaebner commited on
Commit
a8d912f
·
1 Parent(s): 81226cb

add argparsing to train.py; add docstrings; adapt Tox21RFClassifier save and load functions

Browse files
Files changed (4) hide show
  1. data.py +54 -19
  2. model.py +36 -20
  3. predict.py +10 -5
  4. train.py +47 -12
data.py CHANGED
@@ -7,7 +7,6 @@ SMILES and target names as keys.
7
  """
8
 
9
  import os
10
- from typing import List
11
 
12
  import numpy as np
13
 
@@ -27,8 +26,21 @@ def preprocess_molecules(
27
  load_scaler_path: str = "",
28
  save_ecdf_path: str = "",
29
  save_scaler_path: str = "",
30
- ) -> list[int]:
31
- """preprocess a list of molecules"""
 
 
 
 
 
 
 
 
 
 
 
 
 
32
  assert not (
33
  load_ecdf_path and save_ecdf_path
34
  ), "Cannot pass 'load_ecdf_path' and 'save_ecdf_path' simultaneously"
@@ -68,12 +80,12 @@ def preprocess_molecules(
68
  write_pickle(save_ecdf_path, ecdfs)
69
  print(f"Saved ECDFs under {save_ecdf_path}")
70
 
71
- # Create quantils
72
- rdkit_descr_quantils = create_quantils(rdkit_descrs, ecdfs)
73
  print("Created quantiles of RDKit descriptors")
74
 
75
  # Concatenate features
76
- raw_features = np.concatenate((ecfps, rdkit_descr_quantils), axis=1)
77
 
78
  if scaler is None:
79
  scaler = StandardScaler()
@@ -90,9 +102,14 @@ def preprocess_molecules(
90
  return normalized_features, removed_idxs
91
 
92
 
93
- def create_cleaned_mol_objects(smiles: List[str]) -> List[Mol]:
94
- """
95
- This function creates cleaned RDKit mol objects from a list of SMILES.
 
 
 
 
 
96
  """
97
  sm = Standardizer(canon_taut=True)
98
 
@@ -109,9 +126,14 @@ def create_cleaned_mol_objects(smiles: List[str]) -> List[Mol]:
109
  return mols, removed_idxs
110
 
111
 
112
- def create_ecfp_fps(mols: List[Mol]) -> np.ndarray:
113
- """
114
- This function ECFP fingerprints for a list of molecules.
 
 
 
 
 
115
  """
116
  ecfps = list()
117
 
@@ -127,9 +149,14 @@ def create_ecfp_fps(mols: List[Mol]) -> np.ndarray:
127
  return np.array(ecfps)
128
 
129
 
130
- def create_rdkit_descriptors(mols: List[Mol]) -> np.ndarray:
131
- """
132
- This function creates RDKit descriptors for a list of molecules.
 
 
 
 
 
133
  """
134
  rdkit_descriptors = list()
135
 
@@ -145,14 +172,22 @@ def create_rdkit_descriptors(mols: List[Mol]) -> np.ndarray:
145
  return np.array(rdkit_descriptors)
146
 
147
 
148
- def create_quantils(raw_features: np.ndarray, ecdfs: list) -> np.ndarray:
 
 
 
 
 
149
 
150
- quantils = np.zeros_like(raw_features)
 
 
 
151
 
152
  for column in range(raw_features.shape[1]):
153
  raw_values = raw_features[:, column].reshape(-1)
154
  ecdf = ecdfs[column]
155
  q = ecdf(raw_values)
156
- quantils[:, column] = q
157
 
158
- return quantils
 
7
  """
8
 
9
  import os
 
10
 
11
  import numpy as np
12
 
 
26
  load_scaler_path: str = "",
27
  save_ecdf_path: str = "",
28
  save_scaler_path: str = "",
29
+ ) -> tuple[np.ndarray, list[int]]:
30
+ """Preprocessing pipeline for a list of molecules.
31
+
32
+ Args:
33
+ smiles_list (list[str]): list of SMILES
34
+ load_ecdf_path (str, optional): Path to load ECDFs from. Defaults to "".
35
+ load_scaler_path (str, optional): Path to load fitted StandardScaler from. Defaults to "".
36
+ save_ecdf_path (str, optional): Path to save calculated ECDFs. Defaults to "".
37
+ save_scaler_path (str, optional): Path to save fitted StandardScaler. Defaults to "".
38
+
39
+ Returns:
40
+ np.ndarray: normalized ECFPs fingerprints and RDKit descriptor quantiles
41
+ list[int]: list of removed indices of molecules that could not be cleaned
42
+ """
43
+
44
  assert not (
45
  load_ecdf_path and save_ecdf_path
46
  ), "Cannot pass 'load_ecdf_path' and 'save_ecdf_path' simultaneously"
 
80
  write_pickle(save_ecdf_path, ecdfs)
81
  print(f"Saved ECDFs under {save_ecdf_path}")
82
 
83
+ # Create quantiles
84
+ rdkit_descr_quantiles = create_quantiles(rdkit_descrs, ecdfs)
85
  print("Created quantiles of RDKit descriptors")
86
 
87
  # Concatenate features
88
+ raw_features = np.concatenate((ecfps, rdkit_descr_quantiles), axis=1)
89
 
90
  if scaler is None:
91
  scaler = StandardScaler()
 
102
  return normalized_features, removed_idxs
103
 
104
 
105
+ def create_cleaned_mol_objects(smiles: list[str]) -> list[Mol]:
106
+ """This function creates cleaned RDKit mol objects from a list of SMILES.
107
+
108
+ Args:
109
+ smiles (list[str]): list of SMILES
110
+
111
+ Returns:
112
+ list[Mol]: list of cleaned molecules
113
  """
114
  sm = Standardizer(canon_taut=True)
115
 
 
126
  return mols, removed_idxs
127
 
128
 
129
+ def create_ecfp_fps(mols: list[Mol]) -> np.ndarray:
130
+ """This function ECFP fingerprints for a list of molecules.
131
+
132
+ Args:
133
+ mols (list[Mol]): list of molecules
134
+
135
+ Returns:
136
+ np.ndarray: ECFP fingerprints of molecules
137
  """
138
  ecfps = list()
139
 
 
149
  return np.array(ecfps)
150
 
151
 
152
+ def create_rdkit_descriptors(mols: list[Mol]) -> np.ndarray:
153
+ """This function creates RDKit descriptors for a list of molecules.
154
+
155
+ Args:
156
+ mols (list[Mol]): list of molecules
157
+
158
+ Returns:
159
+ np.ndarray: RDKit descriptors of molecules
160
  """
161
  rdkit_descriptors = list()
162
 
 
172
  return np.array(rdkit_descriptors)
173
 
174
 
175
+ def create_quantiles(raw_features: np.ndarray, ecdfs: list) -> np.ndarray:
176
+ """Create quantile values for given features using the columns
177
+
178
+ Args:
179
+ raw_features (np.ndarray): values to put into quantiles
180
+ ecdfs (list): ECDFs to use
181
 
182
+ Returns:
183
+ np.ndarray: computed quantiles
184
+ """
185
+ quantiles = np.zeros_like(raw_features)
186
 
187
  for column in range(raw_features.shape[1]):
188
  raw_values = raw_features[:, column].reshape(-1)
189
  ecdf = ecdfs[column]
190
  q = ecdf(raw_values)
191
+ quantiles[:, column] = q
192
 
193
+ return quantiles
model.py CHANGED
@@ -17,43 +17,59 @@ from utils import TASKS
17
 
18
  # ---------------------------------------------------------------------------------------
19
  class Tox21RFClassifier:
20
- """
21
- A random forest classifier that assigns a toxicity score to a given SMILES string.
22
- """
23
 
24
  def __init__(self, seed: int = 42):
 
 
 
 
 
25
  self.tasks = TASKS
26
  self.model = {
27
  task: RandomForestClassifier(n_estimators=1001, random_state=seed)
28
  for task in self.tasks
29
  }
30
 
31
- def load_model(self, folder: str):
32
- """
33
- Loads the model from a given model checkpoint
34
- """
35
- self.model = {
36
- task: joblib.load(os.path.join(folder, f"rf_{task}.joblib"))
37
- for task in self.tasks
38
- }
39
 
40
- def save_model(self, folder: str):
 
41
  """
42
- Saves the model to a given folder
 
 
 
 
 
 
43
  """
44
- if not os.path.exists(folder):
45
- os.makedirs(folder)
46
 
47
- for task, model in self.model.items():
48
- joblib.dump(model, os.path.join(folder, f"rf_{task}.joblib"))
49
 
50
  def fit(self, task: str, input_features: np.ndarray, labels: np.ndarray) -> None:
 
 
 
 
 
 
 
51
  assert task in self.tasks, f"Unknown task: {task}"
52
  self.model[task].fit(input_features, labels)
53
 
54
- def predict(self, task: str, features: np.ndarray) -> dict:
55
- """
56
- Predicts a given Tox21 targets for a given np.array of molecule features
 
 
 
 
 
 
57
  """
58
  assert task in self.tasks, f"Unknown task: {task}"
59
  preds = self.model[task].predict_proba(features)
 
17
 
18
  # ---------------------------------------------------------------------------------------
19
  class Tox21RFClassifier:
20
+ """A random forest classifier that assigns a toxicity score to a given SMILES string."""
 
 
21
 
22
  def __init__(self, seed: int = 42):
23
+ """Initialize a random forest classifier for each of the 12 Tox21 tasks.
24
+
25
+ Args:
26
+ seed (int, optional): seed for RF to ensure reproducibility. Defaults to 42.
27
+ """
28
  self.tasks = TASKS
29
  self.model = {
30
  task: RandomForestClassifier(n_estimators=1001, random_state=seed)
31
  for task in self.tasks
32
  }
33
 
34
+ def load_model(self, path: str) -> None:
35
+ """Loads the model from a given path
 
 
 
 
 
 
36
 
37
+ Args:
38
+ path (str): path to model checkpoint
39
  """
40
+ self.model = joblib.load(path)
41
+
42
+ def save_model(self, path: str) -> None:
43
+ """Saves the model to a given path
44
+
45
+ Args:
46
+ path (str): path to save model to
47
  """
48
+ if not os.path.exists(os.path.pardir(path)):
49
+ os.makedirs(os.path.pardir(path))
50
 
51
+ joblib.dump(self.model, path)
 
52
 
53
  def fit(self, task: str, input_features: np.ndarray, labels: np.ndarray) -> None:
54
+ """Train the random forest for a given task
55
+
56
+ Args:
57
+ task (str): task to train
58
+ input_features (np.ndarray): training features
59
+ labels (np.ndarray): training labels
60
+ """
61
  assert task in self.tasks, f"Unknown task: {task}"
62
  self.model[task].fit(input_features, labels)
63
 
64
+ def predict(self, task: str, features: np.ndarray) -> np.ndarray:
65
+ """Predicts labels for a given Tox21 target using molecule features
66
+
67
+ Args:
68
+ task (str): the Tox21 target to predict for
69
+ features (np.ndarray): molecule features used for prediction
70
+
71
+ Returns:
72
+ np.ndarray: predicted probability for positive class
73
  """
74
  assert task in self.tasks, f"Unknown task: {task}"
75
  preds = self.model[task].predict_proba(features)
predict.py CHANGED
@@ -6,7 +6,6 @@ SMILES and target names as keys.
6
 
7
  # ---------------------------------------------------------------------------------------
8
  # Dependencies
9
- from typing import List
10
  from collections import defaultdict
11
 
12
  from data import preprocess_molecules
@@ -15,9 +14,15 @@ from model import Tox21RFClassifier
15
  # ---------------------------------------------------------------------------------------
16
 
17
 
18
- def predict(smiles_list: List[str]) -> dict:
19
- """
20
- Applies the classifier to a list of SMILES strings.
 
 
 
 
 
 
21
  """
22
  # preprocessing pipeline
23
  features, removed_idxs = preprocess_molecules(
@@ -28,7 +33,7 @@ def predict(smiles_list: List[str]) -> dict:
28
 
29
  # setup model
30
  model = Tox21RFClassifier(seed=42)
31
- model.load_model("assets/model/")
32
 
33
  # make predicitons
34
  predictions = defaultdict(dict)
 
6
 
7
  # ---------------------------------------------------------------------------------------
8
  # Dependencies
 
9
  from collections import defaultdict
10
 
11
  from data import preprocess_molecules
 
14
  # ---------------------------------------------------------------------------------------
15
 
16
 
17
+ def predict(smiles_list: list[str]) -> dict:
18
+ """Applies the classifier to a list of SMILES strings. Returns prediction=0.0 for
19
+ any molecule that could not be cleaned.
20
+
21
+ Args:
22
+ smiles_list (list[str]): list of SMILES strings
23
+
24
+ Returns:
25
+ dict: nested prediction dictionary, following {'<smiles>': {'<target>': <pred>}}
26
  """
27
  # preprocessing pipeline
28
  features, removed_idxs = preprocess_molecules(
 
33
 
34
  # setup model
35
  model = Tox21RFClassifier(seed=42)
36
+ model.load_model("assets/rf_alltasks.joblib")
37
 
38
  # make predicitons
39
  predictions = defaultdict(dict)
train.py CHANGED
@@ -2,6 +2,8 @@
2
  Script for fitting and saving any preprocessing assets, as well as the fitted RandomForest model
3
  """
4
 
 
 
5
  import numpy as np
6
 
7
  from tabulate import tabulate
@@ -12,9 +14,43 @@ from data import preprocess_molecules
12
  from model import Tox21RFClassifier
13
  from utils import HF_TOKEN
14
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
15
 
16
- def get_sample_mask(removed_idxs: list[int], labels: np.ndarray):
17
- # mask out NaN labels and labels of removed idxs
18
  task_mask = ~np.isnan(labels)
19
  removed_mask = np.ones_like(labels, dtype=bool)
20
  removed_mask[removed_idxs] = 0
@@ -25,25 +61,23 @@ def get_sample_mask(removed_idxs: list[int], labels: np.ndarray):
25
  return feature_mask, label_mask
26
 
27
 
28
- def main():
29
- # save preprocessing scaler and ecdf distributions
30
- save_folder = "assets/model/"
31
  ds = load_dataset("tschouis/tox21", token=HF_TOKEN)
32
 
33
  print("Preprocess train molecules")
34
  train_smiles = list(ds["train"]["smiles"])
35
  train_features, train_removed_idxs = preprocess_molecules(
36
  train_smiles,
37
- save_ecdf_path="assets/ecdfs.pkl",
38
- save_scaler_path="assets/scaler.pkl",
39
  )
40
 
41
  print("Preprocess validation molecules")
42
  val_smiles = list(ds["validation"]["smiles"])
43
  val_features, val_removed_idxs = preprocess_molecules(
44
  val_smiles,
45
- load_ecdf_path="assets/ecdfs.pkl",
46
- load_scaler_path="assets/scaler.pkl",
47
  )
48
 
49
  model = Tox21RFClassifier(seed=42)
@@ -57,8 +91,8 @@ def main():
57
  task, train_features[feature_mask], task_labels[label_mask].astype(int)
58
  )
59
 
60
- print(f"Save model under {save_folder}")
61
- # model.save_model(save_folder)
62
 
63
  print("Evaluate model")
64
  results = {}
@@ -76,4 +110,5 @@ def main():
76
 
77
 
78
  if __name__ == "__main__":
79
- main()
 
 
2
  Script for fitting and saving any preprocessing assets, as well as the fitted RandomForest model
3
  """
4
 
5
+ import argparse
6
+
7
  import numpy as np
8
 
9
  from tabulate import tabulate
 
14
  from model import Tox21RFClassifier
15
  from utils import HF_TOKEN
16
 
17
+ parser = argparse.ArgumentParser(description="RF Trainig script for Tox21 dataset")
18
+
19
+ parser.add_argument(
20
+ "--save_path_model",
21
+ type=str,
22
+ default="assets/rf_alltasks.joblib",
23
+ )
24
+
25
+ parser.add_argument(
26
+ "--save_path_ecdfs",
27
+ type=str,
28
+ default="assets/ecdfs.pkl",
29
+ )
30
+
31
+ parser.add_argument(
32
+ "--save_path_scaler",
33
+ type=str,
34
+ default="assets/scaler.pkl",
35
+ )
36
+
37
+
38
+ def get_sample_mask(
39
+ removed_idxs: list[int], labels: np.ndarray
40
+ ) -> tuple[np.ndarray, np.ndarray]:
41
+ """Returns two masks, one for the samples and one for the labels.
42
+ Filters out any indices removed from the samples and any indices
43
+ where the label is NaN.
44
+
45
+ Args:
46
+ removed_idxs (list[int]): Indices that were removed from the samples
47
+ labels (np.ndarray): list of labels
48
+
49
+ Returns:
50
+ np.ndarray: Feature mask
51
+ np.ndarray: Label mask
52
+ """
53
 
 
 
54
  task_mask = ~np.isnan(labels)
55
  removed_mask = np.ones_like(labels, dtype=bool)
56
  removed_mask[removed_idxs] = 0
 
61
  return feature_mask, label_mask
62
 
63
 
64
+ def main(args):
 
 
65
  ds = load_dataset("tschouis/tox21", token=HF_TOKEN)
66
 
67
  print("Preprocess train molecules")
68
  train_smiles = list(ds["train"]["smiles"])
69
  train_features, train_removed_idxs = preprocess_molecules(
70
  train_smiles,
71
+ save_ecdf_path=args.save_path_ecdfs,
72
+ save_scaler_path=args.save_path_scaler,
73
  )
74
 
75
  print("Preprocess validation molecules")
76
  val_smiles = list(ds["validation"]["smiles"])
77
  val_features, val_removed_idxs = preprocess_molecules(
78
  val_smiles,
79
+ load_ecdf_path=args.save_path_ecdfs,
80
+ load_scaler_path=args.save_path_scaler,
81
  )
82
 
83
  model = Tox21RFClassifier(seed=42)
 
91
  task, train_features[feature_mask], task_labels[label_mask].astype(int)
92
  )
93
 
94
+ print(f"Save model under {args.save_path_model}")
95
+ model.save_model(args.save_path_model)
96
 
97
  print("Evaluate model")
98
  results = {}
 
110
 
111
 
112
  if __name__ == "__main__":
113
+ args = parser.parse_args()
114
+ main(args)