antoniaebner committed on
Commit
1994acc
·
1 Parent(s): 9b322e1

refactoring of feature preprocessing

Browse files
Files changed (7) hide show
  1. config/config.json +27 -8
  2. predict.py +27 -22
  3. preprocess.py +23 -145
  4. src/model.py +13 -87
  5. src/preprocess.py +330 -99
  6. src/utils.py +62 -2
  7. train.py +55 -51
config/config.json CHANGED
@@ -1,14 +1,33 @@
1
  {
2
  "seed": 0,
3
- "ecfp_radius": 3,
4
- "ecfp_fpsize": 8192,
5
- "feature_minvar": 0.01,
6
- "feature_maxcorr": 0.95,
7
- "model_path": "checkpoints/rf_alltasks.joblib",
8
- "data_folder": "data/",
9
  "log_folder": "logs/",
10
- "debug": "false",
11
- "task_configs": {
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
12
  "NR-AR": {
13
  "max_depth": "none",
14
  "max_features": "sqrt",
 
1
  {
2
  "seed": 0,
3
+ "debug": "false",
4
+ "device": "cpu",
5
+
 
 
 
6
  "log_folder": "logs/",
7
+
8
+ "data_folder": "data/",
9
+ "cvfold": 4,
10
+ "ecfp" : {
11
+ "radius": 3,
12
+ "fpsize": 8192
13
+ },
14
+ "descriptors": ["ecfps", "tox", "maccs", "rdkit_descrs"],
15
+ "feature_selection": {
16
+ "use": "true",
17
+ "min_var": 0.01,
18
+ "max_corr": 0.95,
19
+ "feature_keys": ["ecfps", "tox", "maccs", "rdkit_descrs"],
20
+ "max_features": -1
21
+ },
22
+ "feature_quantilization": {
23
+ "use": "true",
24
+ "feature_keys": ["rdkit_descrs"]
25
+ },
26
+ "max_samples": -1,
27
+ "scaler": "standard",
28
+
29
+ "ckpt_path": "checkpoints/rf_alltasks.joblib",
30
+ "model_configs": {
31
  "NR-AR": {
32
  "max_depth": "none",
33
  "max_features": "sqrt",
predict.py CHANGED
@@ -6,15 +6,18 @@ SMILES and target names as keys.
6
 
7
  # ---------------------------------------------------------------------------------------
8
  # Dependencies
 
 
9
  from collections import defaultdict
10
 
11
- import json
12
  import numpy as np
13
  from tqdm import tqdm
14
 
15
- from src.preprocess import create_descriptors
16
- from src.utils import TASKS, normalize_config
17
  from src.model import Tox21RFClassifier
 
 
 
18
 
19
  # ---------------------------------------------------------------------------------------
20
  CONFIG_FILE = "./config/config.json"
@@ -35,20 +38,29 @@ def predict(
35
  print(f"Received {len(smiles_list)} SMILES strings")
36
 
37
  with open(CONFIG_FILE, "r") as f:
38
- cfg = json.load(f)
39
- cfg = normalize_config(cfg)
40
 
41
  features, is_clean = create_descriptors(
42
- smiles_list, radius=cfg["ecfp_radius"], fpsize=cfg["ecfp_fpsize"]
43
  )
44
- n_clean_mols, n_feats = features.shape
45
- print(f"Created {n_feats} descriptors for {n_clean_mols} molecules.")
46
  print(f"{len(is_clean) - sum(is_clean)} molecules removed during cleaning")
47
 
48
  # setup model
49
  model = Tox21RFClassifier()
50
- model.load_model(cfg["model_path"])
51
- print(f"Loaded model from {cfg['model_path']}")
 
 
 
 
 
 
 
 
 
 
52
 
53
  # make predicitons
54
  predictions = defaultdict(dict)
@@ -56,24 +68,17 @@ def predict(
56
  print(f"Create predictions:")
57
  preds = []
58
  for target in tqdm(TASKS):
59
- X = features.copy()
60
- preds = np.empty_like(is_clean, dtype=np.float64)
 
61
 
 
62
  preds[~is_clean] = default_prediction
63
  preds[is_clean] = model.predict(target, X)
64
 
65
  for smiles, pred in zip(smiles_list, preds):
66
  predictions[smiles][target] = float(pred)
67
- if cfg["debug"]:
68
  break
69
 
70
  return predictions
71
-
72
-
73
- # from hiddens.testing import test_eval
74
-
75
- # with open(CONFIG_FILE, "r") as f:
76
- # config = json.load(f)
77
- # config = normalize_config(config)
78
-
79
- # test_eval(predict, debug=config["debug"], use_only_clean=False, use_only_first=False)
 
6
 
7
  # ---------------------------------------------------------------------------------------
8
  # Dependencies
9
+ import json
10
+ import copy
11
  from collections import defaultdict
12
 
13
+ import joblib
14
  import numpy as np
15
  from tqdm import tqdm
16
 
 
 
17
  from src.model import Tox21RFClassifier
18
+ from src.preprocess import create_descriptors, FeaturePreprocessor
19
+ from src.utils import TASKS, normalize_config
20
+
21
 
22
  # ---------------------------------------------------------------------------------------
23
  CONFIG_FILE = "./config/config.json"
 
38
  print(f"Received {len(smiles_list)} SMILES strings")
39
 
40
  with open(CONFIG_FILE, "r") as f:
41
+ config = json.load(f)
42
+ config = normalize_config(config)
43
 
44
  features, is_clean = create_descriptors(
45
+ smiles_list, config["descriptors"], **config["ecfp"]
46
  )
47
+ print(f"Created descriptors for {sum(is_clean)} molecules.")
 
48
  print(f"{len(is_clean) - sum(is_clean)} molecules removed during cleaning")
49
 
50
  # setup model
51
  model = Tox21RFClassifier()
52
+ preprocessor = FeaturePreprocessor(
53
+ feature_selection_config=config["feature_selection"],
54
+ feature_quantilization_config=config["feature_quantilization"],
55
+ descriptors=config["descriptors"],
56
+ max_samples=config["max_samples"],
57
+ scaler=config["scaler"],
58
+ )
59
+
60
+ ckpt = joblib.load(config["ckpt_path"])
61
+ model.set_state(ckpt["models"])
62
+ preprocessor.__setstate__(ckpt["preprocessor"])
63
+ print(f"Loaded model & preprocessor from {config['ckpt_path']}")
64
 
65
  # make predicitons
66
  predictions = defaultdict(dict)
 
68
  print(f"Create predictions:")
69
  preds = []
70
  for target in tqdm(TASKS):
71
+ X = copy.deepcopy(features)
72
+ X = {descr: array[is_clean] for descr, array in X.items()}
73
+ X = preprocessor.transform(X)
74
 
75
+ preds = np.empty_like(is_clean, dtype=np.float64)
76
  preds[~is_clean] = default_prediction
77
  preds[is_clean] = model.predict(target, X)
78
 
79
  for smiles, pred in zip(smiles_list, preds):
80
  predictions[smiles][target] = float(pred)
81
+ if config["debug"]:
82
  break
83
 
84
  return predictions
 
 
 
 
 
 
 
 
 
preprocess.py CHANGED
@@ -7,186 +7,64 @@ SMILES and target names as keys.
7
  """
8
 
9
  import os
 
10
  import argparse
11
 
12
  import numpy as np
13
 
14
  from src.preprocess import create_descriptors, get_tox21_split
15
- from src.utils import (
16
- TASKS,
17
- HF_TOKEN,
18
- create_dir,
19
- )
20
 
21
  parser = argparse.ArgumentParser(
22
  description="Data preprocessing script for the Tox21 dataset"
23
  )
24
 
25
  parser.add_argument(
26
- "--save_folder",
27
- type=str,
28
- default="data/",
29
- help="Folder to which preprocessed the data CSV and NPZ files should be saved.",
30
- )
31
-
32
- parser.add_argument(
33
- "--cv_fold",
34
- type=int,
35
- default=4,
36
- help="Select fold used as validation set.",
37
- )
38
-
39
- parser.add_argument(
40
- "--feature_selection",
41
- type=int,
42
- default=1,
43
- help="True (=1) to use feature selection.",
44
- )
45
-
46
- parser.add_argument(
47
- "--feature_selection_path",
48
- type=str,
49
- default="feat_selection.npz",
50
- help="Filename for saving feature selections.",
51
- )
52
-
53
- parser.add_argument(
54
- "--min_var",
55
- type=float,
56
- default=0.01,
57
- help="Minimum variance threshold for selecting features.",
58
- )
59
-
60
- parser.add_argument(
61
- "--max_corr",
62
- type=float,
63
- default=0.95,
64
- help="Maximum correlation threshold for selecting features.",
65
- )
66
-
67
- parser.add_argument(
68
- "--ecdfs_path",
69
  type=str,
70
- default="ecdfs.pkl",
71
- help="Filename to save ECDFs.",
72
  )
73
 
74
- parser.add_argument(
75
- "--ecfps_radius",
76
- type=int,
77
- default=3,
78
- help="Radius used for creating ECFPs.",
79
- )
80
 
81
- parser.add_argument(
82
- "--ecfps_folds",
83
- type=int,
84
- default=8192,
85
- help="Folds used for creating ECFPs.",
86
- )
87
 
88
- parser.add_argument(
89
- "--ecdfs",
90
- type=int,
91
- default=1,
92
- help="True (=1) to use ECDFs for creating quantiles of the RDKit descriptors.",
93
- )
94
-
95
-
96
- def main(args):
97
- """Preprocessing train/val data to use for TabPFN.
98
-
99
- 1. Download Tox21 train/val data from HF
100
- 2. Preprocess dataset splits
101
- """
102
- ds = get_tox21_split(HF_TOKEN, cvfold=args.cv_fold)
103
-
104
- feature_creation_kwargs = {
105
- "radius": args.ecfps_radius,
106
- "fpsize": args.ecfps_folds,
107
- "min_var": args.min_var,
108
- "max_corr": args.max_corr,
109
- }
110
- removed_mols = 0
111
-
112
- splits = ["train", "validation", "test"]
113
  for split in splits:
114
 
115
  print(f"Preprocess {split} molecules")
116
 
117
- if split != "test":
118
- ds_split = ds[split]
119
- smiles = list(ds_split["smiles"])
120
- else:
121
- import pandas as pd
122
-
123
- ds_split = pd.read_csv("data/tox21_test_cv4.csv")
124
-
125
- smiles = ds_split["smiles"]
126
-
127
- features, clean_mol_mask = create_descriptors(smiles, **feature_creation_kwargs)
128
-
129
- # if split == "train":
130
- # output = create_descriptors(
131
- # smiles,
132
- # return_feature_selection=True,
133
- # return_ecdfs=True,
134
- # **feature_creation_kwargs,
135
- # )
136
- # features = output.pop("features")
137
-
138
- # if args.feature_selection:
139
- # feature_selection = output.pop("feature_selection")
140
- # np.savez(
141
- # args.feature_selection_path,
142
- # ecfps_selec=feature_selection["ecfps_selec"],
143
- # tox_selec=feature_selection["tox_selec"],
144
- # )
145
-
146
- # print(f"Saved feature selection under {args.feature_selection_path}")
147
-
148
- # if args.ecdfs:
149
- # ecdfs = output.pop("ecdfs")
150
- # write_pickle(args.ecdfs_path, ecdfs)
151
- # print(f"Saved ECDFs under {args.ecdfs_path}")
152
-
153
- # else:
154
- # features = create_descriptors(
155
- # smiles,
156
- # ecdfs=ecdfs,
157
- # feature_selection=feature_selection,
158
- # **feature_creation_kwargs,
159
- # )["features"]
160
- removed_mols += (~clean_mol_mask).sum()
161
 
162
  labels = []
163
  for task in TASKS:
164
  labels.append(ds_split[task].to_numpy())
165
  labels = np.stack(labels, axis=1)
166
 
167
- save_path = os.path.join(args.save_folder, f"tox21_{split}_cv4.npz")
168
  with open(save_path, "wb") as f:
169
  np.savez(
170
  f,
171
- labels=labels[clean_mol_mask, :],
172
- features=features,
173
- # **features,
174
  )
175
  print(f"Saved preprocessed {split} split under {save_path}")
176
- print(f"{removed_mols} mols were removed during cleaning across all datasets")
177
  print("Preprocessing finished successfully")
178
 
179
 
180
  if __name__ == "__main__":
181
  args = parser.parse_args()
182
 
183
- # args.ecdfs_path = os.path.join(args.save_folder, args.ecdfs_path)
184
- # args.feature_selection_path = os.path.join(
185
- # args.save_folder, args.feature_selection_path
186
- # )
187
-
188
- create_dir(args.save_folder)
189
- # create_dir(args.ecdfs_path, is_file=True)
190
- # create_dir(args.feature_selection_path, is_file=True)
191
 
192
- main(args)
 
 
7
  """
8
 
9
  import os
10
+ import json
11
  import argparse
12
 
13
  import numpy as np
14
 
15
  from src.preprocess import create_descriptors, get_tox21_split
16
+ from src.utils import TASKS, HF_TOKEN, create_dir, normalize_config
 
 
 
 
17
 
18
  parser = argparse.ArgumentParser(
19
  description="Data preprocessing script for the Tox21 dataset"
20
  )
21
 
22
  parser.add_argument(
23
+ "--config",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
24
  type=str,
25
+ default="config/config.json",
 
26
  )
27
 
 
 
 
 
 
 
28
 
29
def main(config):
    """Create and store molecule descriptors for the Tox21 train/validation splits.

    For each split, descriptors are computed from the SMILES column and written,
    together with the clean-molecule mask and the per-task label matrix, to an
    ``.npz`` archive inside ``config["data_folder"]``.

    Args:
        config: Configuration dictionary. Reads the keys ``cvfold``,
            ``descriptors``, ``ecfp`` and ``data_folder``.
    """
    dataset = get_tox21_split(HF_TOKEN, cvfold=config["cvfold"])

    for split in ("train", "validation"):
        print(f"Preprocess {split} molecules")

        split_data = dataset[split]
        features, clean_mol_mask = create_descriptors(
            list(split_data["smiles"]), config["descriptors"], **config["ecfp"]
        )

        # One column per task, in the order given by TASKS.
        labels = np.stack([split_data[task].to_numpy() for task in TASKS], axis=1)

        # NOTE: the file name always carries the "cv4" suffix, independent of
        # the configured ``cvfold``.
        save_path = os.path.join(config["data_folder"], f"tox21_{split}_cv4.npz")
        with open(save_path, "wb") as f:
            np.savez(
                f,
                clean_mol_mask=clean_mol_mask,
                labels=labels,
                **features,
            )
        print(f"Saved preprocessed {split} split under {save_path}")

    print("Preprocessing finished successfully")
60
 
61
 
62
  if __name__ == "__main__":
63
  args = parser.parse_args()
64
 
65
+ with open(args.config, "r") as f:
66
+ config = json.load(f)
67
+ config = normalize_config(config)
 
 
 
 
 
68
 
69
+ create_dir(config["data_folder"])
70
+ main(config)
src/model.py CHANGED
@@ -6,15 +6,9 @@ SMILES and target names as keys.
6
 
7
  # ---------------------------------------------------------------------------------------
8
  # Dependencies
9
- import os
10
- import joblib
11
-
12
  import numpy as np
13
-
14
  from sklearn.ensemble import RandomForestClassifier
15
- from sklearn.preprocessing import StandardScaler
16
 
17
- from .preprocess import get_feature_selection, get_ecdfs, create_quantiles
18
  from .utils import TASKS
19
 
20
 
@@ -22,100 +16,34 @@ from .utils import TASKS
22
  class Tox21RFClassifier:
23
  """A random forest classifier that assigns a toxicity score to a given SMILES string."""
24
 
25
- def __init__(
26
- self, seed: int = 42, task_config: dict = None, rdkit_desc_idxs: list[int] = []
27
- ):
28
  """Initialize a random forest classifier for each of the 12 Tox21 tasks.
29
 
30
  Args:
31
  seed (int, optional): seed for RF to ensure reproducibility. Defaults to 42.
32
  """
33
  self.tasks = TASKS
34
- self.rdkit_desc_idxs = rdkit_desc_idxs
35
 
36
  self.models = {
37
  task: RandomForestClassifier(
38
  random_state=seed,
39
  n_jobs=8,
40
- **(
41
- {"n_estimators": 1000} if task_config is None else task_config[task]
42
- ),
43
  )
44
  for task in self.tasks
45
  }
46
- self.feature_selection = None
47
- self.ecdfs = None
48
- self.scaler = StandardScaler()
49
-
50
- def load_model(self, path: str) -> None:
51
- """Loads the model from a given path
52
-
53
- Args:
54
- path (str): path to model checkpoint
55
- """
56
- model = joblib.load(path)
57
-
58
- self.models = model["models"]
59
- self.scaler = model["scalers"]
60
- self.rdkit_desc_idxs = model["rdkit_desc_idxs"]
61
-
62
- self.feature_selection = model["feature_selections"]
63
- self.ecdfs = model["ecdfs"]
64
 
65
- def save_model(self, path: str) -> None:
66
- """Saves the model to a given path
67
 
68
  Args:
69
- path (str): path to save model to
70
  """
71
- if not os.path.exists(os.path.dirname(path)):
72
- os.makedirs(os.path.dirname(path))
73
-
74
- model = {
75
- "models": self.models,
76
- "feature_selections": self.feature_selection,
77
- "ecdfs": self.ecdfs,
78
- "scalers": self.scaler,
79
- "rdkit_desc_idxs": self.rdkit_desc_idxs,
80
- }
81
 
82
- joblib.dump(model, path)
83
-
84
- def fit_preprocessing(self, X: np.ndarray, min_var=0.01, max_corr=0.95) -> None:
85
- X_ = X.copy()
86
-
87
- _, n_feat = X.shape
88
-
89
- if self.rdkit_desc_idxs is None:
90
- self.rdkit_desc_idxs = np.arange(n_feat)
91
- else:
92
- assert (
93
- self.rdkit_desc_idxs < n_feat
94
- ).all(), "passed to_adapt list contains more features than in X!"
95
-
96
- self.ecdfs = get_ecdfs(X_[:, self.rdkit_desc_idxs])
97
- X_[:, self.rdkit_desc_idxs] = create_quantiles(
98
- X_[:, self.rdkit_desc_idxs], self.ecdfs
99
- )
100
-
101
- # get feature selection
102
- self.feature_selection = get_feature_selection(
103
- X_, min_var=min_var, max_corr=max_corr
104
- )
105
- X_ = X_[:, self.feature_selection]
106
-
107
- # fit scaler
108
- X_ = self.scaler.fit(X_)
109
-
110
- def _preprocess(self, X: np.ndarray) -> None:
111
- X_ = X.copy()
112
-
113
- X_[:, self.rdkit_desc_idxs] = create_quantiles(
114
- X_[:, self.rdkit_desc_idxs], self.ecdfs
115
- )
116
- X_ = X_[:, self.feature_selection]
117
- X_ = self.scaler.transform(X_)
118
- return X_
119
 
120
  def fit(self, task: str, X: np.ndarray, y: np.ndarray) -> None:
121
  """Train the random forest for a given task
@@ -126,9 +54,8 @@ class Tox21RFClassifier:
126
  y (np.ndarray): training labels
127
  """
128
  assert task in self.tasks, f"Unknown task: {task}"
129
-
130
- X_ = self._preprocess(X)
131
- self.models[task].fit(X_, y)
132
 
133
  def predict(self, task: str, X: np.ndarray) -> np.ndarray:
134
  """Predicts labels for a given Tox21 target using molecule features
@@ -144,6 +71,5 @@ class Tox21RFClassifier:
144
  assert (
145
  len(X.shape) == 2
146
  ), f"Function expects 2D np.array. Current shape: {X.shape}"
147
-
148
- X_ = self._preprocess(X)
149
- return self.models[task].predict_proba(X_)[:, 1]
 
6
 
7
  # ---------------------------------------------------------------------------------------
8
  # Dependencies
 
 
 
9
  import numpy as np
 
10
  from sklearn.ensemble import RandomForestClassifier
 
11
 
 
12
  from .utils import TASKS
13
 
14
 
 
16
  class Tox21RFClassifier:
17
  """A random forest classifier that assigns a toxicity score to a given SMILES string."""
18
 
19
+ def __init__(self, seed: int = 42, config: dict = None):
 
 
20
  """Initialize a random forest classifier for each of the 12 Tox21 tasks.
21
 
22
  Args:
23
  seed (int, optional): seed for RF to ensure reproducibility. Defaults to 42.
24
  """
25
  self.tasks = TASKS
 
26
 
27
  self.models = {
28
  task: RandomForestClassifier(
29
  random_state=seed,
30
  n_jobs=8,
31
+ **({"n_estimators": 1000} if config is None else config[task]),
 
 
32
  )
33
  for task in self.tasks
34
  }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
35
 
36
+ def set_state(self, state: dict) -> None:
37
+ """Sets the state of the model
38
 
39
  Args:
40
+ state (dict): models state dict
41
  """
42
+ self.models = state
 
 
 
 
 
 
 
 
 
43
 
44
+ def get_state(self) -> None:
45
+ """Return model state dict"""
46
+ return {"models": self.models}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
47
 
48
  def fit(self, task: str, X: np.ndarray, y: np.ndarray) -> None:
49
  """Train the random forest for a given task
 
54
  y (np.ndarray): training labels
55
  """
56
  assert task in self.tasks, f"Unknown task: {task}"
57
+ _X, _y = X.copy(), y.copy()
58
+ self.models[task].fit(_X, _y)
 
59
 
60
  def predict(self, task: str, X: np.ndarray) -> np.ndarray:
61
  """Predicts labels for a given Tox21 target using molecule features
 
71
  assert (
72
  len(X.shape) == 2
73
  ), f"Function expects 2D np.array. Current shape: {X.shape}"
74
+ _X = X.copy()
75
+ return self.models[task].predict_proba(_X)[:, 1]
 
src/preprocess.py CHANGED
@@ -6,20 +6,304 @@ As an input it takes a list of SMILES and it outputs a nested dictionary with
6
  SMILES and target names as keys.
7
  """
8
 
 
9
  import json
 
10
 
11
  import numpy as np
12
  import pandas as pd
13
 
14
  from datasets import load_dataset
 
15
  from sklearn.feature_selection import VarianceThreshold
 
16
  from statsmodels.distributions.empirical_distribution import ECDF
17
 
18
  from rdkit import Chem, DataStructs
19
  from rdkit.Chem import Descriptors, rdFingerprintGenerator, MACCSkeys
20
  from rdkit.Chem.rdchem import Mol
21
 
22
- from .utils import USED_200_DESCR, TOX_SMARTS_PATH, Standardizer
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
23
 
24
 
25
  def create_cleaned_mol_objects(smiles: list[str]) -> tuple[list[Mol], np.ndarray]:
@@ -198,112 +482,59 @@ def fill(features, mask, value=np.nan):
198
 
199
  def create_descriptors(
200
  smiles,
201
- ecdfs=None,
202
- feature_selection=None,
203
- return_ecdfs=False,
204
- return_feature_selection=False,
205
- **kwargs,
206
  ):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
207
  # Create cleanded rdkit mol objects
208
  mols, clean_mol_mask = create_cleaned_mol_objects(smiles)
209
- print("Cleaned molecules")
210
-
211
- tox_patterns = get_tox_patterns(TOX_SMARTS_PATH)
212
 
213
  # Create fingerprints and descriptors
214
- ecfps = create_ecfp_fps(mols, **kwargs)
215
- # expand using mol_mask
216
- # ecfps = fill(ecfps, ~clean_mol_mask)
217
- print("Created ECFP fingerprints")
218
- # print("ecfps features:", ecfps.shape)
219
-
220
- tox = create_tox_features(mols, tox_patterns)
221
- # tox = fill(tox, ~clean_mol_mask)
222
- print("Created Tox features")
223
- # print("tox features:", tox.shape)
224
-
225
- # Create and save feature selection for ecfps and tox
226
- # if feature_selection is None:
227
- # print("Create Feature selection")
228
- # ecfps_selec = get_feature_selection(ecfps, **kwargs)
229
- # tox_selec = get_feature_selection(tox, **kwargs)
230
- # feature_selection = {"ecfps_selec": ecfps_selec, "tox_selec": tox_selec}
231
-
232
- # else:
233
- # ecfps_selec = feature_selection["ecfps_selec"]
234
- # tox_selec = feature_selection["tox_selec"]
235
-
236
- # ecfps = ecfps[:, ecfps_selec]
237
- # tox = tox[:, tox_selec]
238
-
239
- maccs = create_maccs_keys(mols)
240
- # maccs = fill(maccs, ~clean_mol_mask)
241
- print("Created MACCS keys")
242
-
243
- rdkit_descrs = create_rdkit_descriptors(mols)
244
- # rdkit_descrs = fill(rdkit_descrs, ~clean_mol_mask)
245
- print("Created RDKit descriptors")
246
-
247
- # # Create and save ecdfs
248
- # if ecdfs is None:
249
- # print("Create ECDFs")
250
- # ecdfs = []
251
- # for column in range(rdkit_descrs.shape[1]):
252
- # raw_values = rdkit_descrs[:, column].reshape(-1)
253
- # ecdfs.append(ECDF(raw_values))
254
-
255
- # # Create quantiles
256
- # rdkit_descr_quantiles = create_quantiles(rdkit_descrs, ecdfs)
257
- # # expand using mol_mask
258
- # rdkit_descr_quantiles = fill(rdkit_descr_quantiles, ~clean_mol_mask)
259
- # print("Created quantiles of RDKit descriptors")
260
 
261
  # concatenate features
262
- # features = {
263
- # "ecfps": ecfps,
264
- # "tox": tox,
265
- # "maccs": maccs,
266
- # "rdkit_descr_quantiles": rdkit_descr_quantiles,
267
- # }
268
- # for feat in [ecfps, tox, maccs, rdkit_descrs]:
269
- # print(feat.shape)
270
- features = np.concat((ecfps, tox, maccs, rdkit_descrs), axis=1)
271
- # return_dict = {"features": features}
272
- # if return_ecdfs:
273
- # return_dict["ecdfs"] = ecdfs
274
- # if return_feature_selection:
275
- # return_dict["feature_selection"] = feature_selection
276
- return features, clean_mol_mask
277
 
278
-
279
- def get_ecdfs(raw_features: np.ndarray, **kwargs) -> np.ndarray:
280
- ecdfs = []
281
- for column in range(raw_features.shape[1]):
282
- raw_values = raw_features[:, column].reshape(-1)
283
- ecdfs.append(ECDF(raw_values))
284
- return ecdfs
285
-
286
-
287
- def get_feature_selection(
288
- raw_features: np.ndarray, min_var=0.01, max_corr=0.95, **kwargs
289
- ) -> np.ndarray:
290
- # select features with at least min_var variation
291
- var_thresh = VarianceThreshold(threshold=min_var)
292
- feature_selection = var_thresh.fit(raw_features).get_support(indices=True)
293
-
294
- n_features_preselected = len(feature_selection)
295
-
296
- # Remove highly correlated features
297
- corr_matrix = np.corrcoef(raw_features[:, feature_selection], rowvar=False)
298
- upper_tri = np.triu(corr_matrix, k=1)
299
- to_keep = np.ones((n_features_preselected,), dtype=bool)
300
- for i in range(upper_tri.shape[0]):
301
- for j in range(upper_tri.shape[1]):
302
- if upper_tri[i, j] > max_corr:
303
- to_keep[j] = False
304
-
305
- feature_selection = feature_selection[to_keep]
306
- return feature_selection
307
 
308
 
309
  def get_tox21_split(token, cvfold=None):
 
6
  SMILES and target names as keys.
7
  """
8
 
9
+ import copy
10
  import json
11
+ from typing import Any
12
 
13
  import numpy as np
14
  import pandas as pd
15
 
16
  from datasets import load_dataset
17
+ from sklearn.base import BaseEstimator, TransformerMixin
18
  from sklearn.feature_selection import VarianceThreshold
19
+ from sklearn.preprocessing import StandardScaler, FunctionTransformer
20
  from statsmodels.distributions.empirical_distribution import ECDF
21
 
22
  from rdkit import Chem, DataStructs
23
  from rdkit.Chem import Descriptors, rdFingerprintGenerator, MACCSkeys
24
  from rdkit.Chem.rdchem import Mol
25
 
26
+ from .utils import USED_200_DESCR, TOX_SMARTS_PATH, Standardizer, FeatureDictMixin
27
+
28
+
29
class SquashScaler(TransformerMixin, BaseEstimator):
    """
    Scaler that performs sequential standardization, nonlinearity (tanh), and
    re-standardization. Inspired by DeepTox (Mayr et al., 2016)

    Both standardization steps are fitted in ``fit``; ``transform`` applies
    standardize -> tanh -> standardize with the fitted statistics.
    """

    def __init__(self):
        self.scaler1 = StandardScaler()
        self.scaler2 = StandardScaler()

    def fit(self, X, y=None):
        """Fit both standardization steps on ``X``.

        Args:
            X: 2D array-like of shape (n_samples, n_features).
            y: Ignored. Accepted so the scaler can be used wherever the
                scikit-learn ``fit(X, y)`` calling convention applies
                (e.g. ``fit_transform(X, y)`` or inside a pipeline).

        Returns:
            The fitted scaler (``self``).
        """
        # The second scaler must see the data as it looks after the squashing.
        squashed = np.tanh(self.scaler1.fit_transform(X))
        self.scaler2.fit(squashed)
        return self

    def transform(self, X):
        """Apply standardize -> tanh -> standardize to ``X`` and return the result."""
        squashed = np.tanh(self.scaler1.transform(X))
        return self.scaler2.transform(squashed)
51
+
52
+
53
+ SCALER_REGISTRY = {
54
+ "none": FunctionTransformer,
55
+ "standard": StandardScaler,
56
+ "squash": SquashScaler,
57
+ }
58
+
59
+
60
class SubSampler(TransformerMixin, BaseEstimator):
    """
    Preprocessor that randomly samples at most `max_samples` rows from data.

    Rows are drawn without replacement, so no sample is duplicated. If the data
    holds `max_samples` rows or fewer, it is returned unchanged (as a copy).

    Args:
        max_samples (int): Maximum allowed samples. If -1, all samples are retained.

    Input:
        np.ndarray: A 2D NumPy array of shape (n_samples, n_features).

    Output:
        np.ndarray: Subsampled array of shape (min(n_samples, max_samples), n_features).
    """

    def __init__(self, *, max_samples=-1):
        self.max_samples = max_samples
        # Stateless transformer: nothing is learned in `fit`.
        self.is_fitted_ = True

    def fit(self, X: np.ndarray, y: np.ndarray | None = None):
        """No-op; present for scikit-learn API compatibility."""
        return self

    def transform(
        self, X: np.ndarray, y: np.ndarray | None = None
    ) -> np.ndarray | tuple[np.ndarray, np.ndarray]:
        """Return a random subset of the rows of ``X`` (and ``y``, if given).

        Args:
            X: Array whose first axis indexes samples.
            y: Optional labels aligned with ``X``; subsampled with the same indices.

        Returns:
            The subsampled ``X``, or the tuple ``(X, y)`` when ``y`` is passed.
        """
        _X = X.copy()
        _y = y.copy() if y is not None else None

        n_samples = _X.shape[0]
        # Only subsample when there is something to drop; sampling with
        # replacement would duplicate rows and upsample small inputs.
        if 0 < self.max_samples < n_samples:
            keep_idxs = np.random.choice(
                n_samples, size=self.max_samples, replace=False
            )
            _X = _X[keep_idxs]
            if _y is not None:
                _y = _y[keep_idxs]

        if _y is None:
            return _X
        return _X, _y
98
+
99
+
100
class FeatureSelector(FeatureDictMixin, TransformerMixin, BaseEstimator):
    """
    Preprocessor that performs feature selection based on variance and correlation.

    This transformer selects features that:
    1. Have variance above a specified threshold.
    2. Are below a given pairwise correlation threshold.
    3. Among the remaining features, keeps only the top `max_features` with the highest variance.

    The input and output are both dictionaries mapping feature types to their corresponding
    feature matrices.

    Args:
        min_var (float): Minimum variance required for a feature to be retained.
            Values <= 0 disable the variance filter.
        max_corr (float): Maximum allowed correlation between features.
            Features exceeding this threshold with others are removed.
            Values >= 1 disable the correlation filter.
        max_features (int): Maximum number of features to keep after filtering.
            If -1, all remaining features are retained.

    Input:
        dict[str, np.ndarray]: A dictionary where each key corresponds to a feature type
            and each value is a 2D NumPy array of shape (n_samples, n_features).

    Output:
        dict[str, np.ndarray]: A dictionary with the same keys as the input,
            containing only the selected features for each feature type.
    """

    def __init__(
        self, *, min_var=0.0, max_corr=1.0, max_features=-1, feature_keys=None
    ):
        self.min_var = min_var
        self.max_corr = max_corr
        self.max_features = max_features
        self._feature_mask = None

        super().__init__(feature_keys=feature_keys)

    def fit(self, X: dict[str, np.ndarray], y=None):
        """Determine which feature columns to keep.

        Args:
            X: Feature dictionary; converted to one 2D array via ``dict_to_array``.
            y: Ignored; accepted for scikit-learn API compatibility.

        Returns:
            The fitted selector (``self``).

        Raises:
            ValueError: If ``max_features`` is 0.
        """
        if self.max_features == 0:
            raise ValueError(
                f"max_features (={self.max_features}) must be -1 or larger 0."
            )

        _X = self.dict_to_array(X)

        # Start from "keep everything" so the mask is defined even when the
        # variance filter is disabled (min_var <= 0 is the default).
        feature_mask = np.ones(_X.shape[1], dtype=bool)

        # select features with at least min_var variation
        if self.min_var > 0.0:
            var_thresh = VarianceThreshold(threshold=self.min_var)
            feature_mask = var_thresh.fit(_X).get_support()  # mask

        # drop features that correlate above max_corr with an earlier feature;
        # with fewer than two candidates there is no pair to compare
        if self.max_corr < 1.0 and feature_mask.sum() > 1:
            corr_matrix = np.corrcoef(_X[:, feature_mask], rowvar=False)
            upper_tri = np.triu(corr_matrix, k=1)
            # Column j is dropped if any row i < j correlates too strongly with it.
            to_keep = ~(upper_tri > self.max_corr).any(axis=0)
            feature_mask[feature_mask] = to_keep

        if self.max_features > 0:
            # keep only the max_features remaining features with highest variance
            selected_idx = np.flatnonzero(feature_mask)
            feature_vars = np.nanvar(_X[:, selected_idx], axis=0)
            order = np.argsort(feature_vars)[::-1][: self.max_features]
            feature_mask = np.zeros_like(feature_mask)
            feature_mask[selected_idx[order]] = True

        self._feature_mask = feature_mask
        self.is_fitted_ = True
        return self

    def transform(self, X: dict[str, np.ndarray]) -> dict[str, np.ndarray]:
        """Return ``X`` restricted to the feature columns selected in ``fit``."""
        _X = self.dict_to_array(X)
        _X = _X[:, self._feature_mask]
        # NOTE(review): `_curr_keys` is not defined in this class; presumably it
        # is populated by `dict_to_array` in FeatureDictMixin and consumed by
        # `array_to_dict` — confirm against src/utils.py.
        self._curr_keys = self._curr_keys[self._feature_mask]
        return self.array_to_dict(_X)
180
+
181
+
182
class QuantileCreator(FeatureDictMixin, TransformerMixin, BaseEstimator):
    """
    Preprocessor that transforms features into empirical quantiles using ECDFs.

    This transformer applies an Empirical Cumulative Distribution Function (ECDF)
    to each feature and replaces feature values with their corresponding quantile
    ranks. One ECDF is fitted per feature column.

    Both input and output are dictionaries mapping feature types to their
    corresponding feature matrices.

    Input:
        dict[str, np.ndarray]: A dictionary where each key corresponds to a feature type
            and each value is a 2D NumPy array of shape (n_samples, n_features).

    Output:
        dict[str, np.ndarray]: A dictionary with the same keys as the input,
            where each feature value is replaced by its corresponding ECDF quantile rank.
    """

    def __init__(self, *, feature_keys=None):
        self._ecdfs = None
        super().__init__(feature_keys=feature_keys)

    def fit(self, X: dict[str, np.ndarray], y=None):
        """Fit one ECDF per feature column.

        Args:
            X: Feature dictionary; converted to one 2D array via ``dict_to_array``.
            y: Ignored; accepted for scikit-learn API compatibility.

        Returns:
            The fitted transformer (``self``).
        """
        _X = self.dict_to_array(X)
        self._ecdfs = [ECDF(_X[:, column]) for column in range(_X.shape[1])]
        self.is_fitted_ = True
        return self

    def transform(self, X: dict[str, np.ndarray]) -> dict[str, np.ndarray]:
        """Replace every feature value by its quantile under the fitted ECDF."""
        _X = self.dict_to_array(X)

        # Quantiles are fractions; allocate a float buffer explicitly so that an
        # integer-typed input cannot truncate them when they are written back.
        quantiles = np.zeros(_X.shape, dtype=np.float64)
        for column in range(_X.shape[1]):
            ecdf = self._ecdfs[column]
            quantiles[:, column] = ecdf(_X[:, column])

        return self.array_to_dict(quantiles)
227
+
228
+
229
class FeaturePreprocessor(TransformerMixin, BaseEstimator):
    """Turn a dictionary of molecule features into a single feature matrix.

    Steps, in order: optional ECDF quantilization (``QuantileCreator``),
    optional feature selection (``FeatureSelector``), concatenation of the
    descriptor blocks listed in ``descriptors``, scaling, and finally the
    ``SubSampler``.

    Args:
        feature_selection_config: Keyword arguments for ``FeatureSelector``
            plus the mandatory key ``"use"`` that enables/disables the step.
        feature_quantilization_config: Keyword arguments for
            ``QuantileCreator`` plus the mandatory key ``"use"``.
        descriptors: Feature types to concatenate, in this order.
        max_samples: Forwarded to ``SubSampler``.
        scaler: Key into ``SCALER_REGISTRY`` selecting the scaler class.

    Raises:
        KeyError: If a config lacks the ``"use"`` key or ``scaler`` is not
            registered.
    """

    def __init__(
        self,
        feature_selection_config: dict[str, Any],
        feature_quantilization_config: dict[str, Any],
        descriptors: list[str],
        max_samples: int = -1,
        scaler: str = "standard",
    ):
        self.descriptors = descriptors

        # Work on shallow copies: popping "use" from the caller's dicts would
        # corrupt the shared config and make a second construction (or a
        # clone built from the stored parameters) fail with a KeyError.
        self.feature_quantilization_config = feature_quantilization_config
        quantilization_kwargs = dict(feature_quantilization_config)
        self.use_feat_quant = quantilization_kwargs.pop("use")
        self.quantile_creator = QuantileCreator(**quantilization_kwargs)

        self.feature_selection_config = feature_selection_config
        selection_kwargs = dict(feature_selection_config)
        self.use_feat_selec = selection_kwargs.pop("use")
        self.feature_selector = FeatureSelector(**selection_kwargs)

        self.max_samples = max_samples
        self.sub_sampler = SubSampler(max_samples=max_samples)

        self.scaler = SCALER_REGISTRY[scaler]()

    def __getstate__(self):
        """Return the state with every sub-transformer replaced by its own state."""
        # Copy before editing: the parent may hand back the live instance
        # ``__dict__`` (default ``object.__getstate__`` on Python 3.11+), and
        # writing into that would replace the sub-transformers on ``self``
        # with plain dicts.
        state = dict(super().__getstate__())
        state["quantile_creator"] = self.quantile_creator.__getstate__()
        state["feature_selector"] = self.feature_selector.__getstate__()
        state["sub_sampler"] = self.sub_sampler.__getstate__()
        state["scaler"] = self.scaler.__getstate__()
        return state

    def __setstate__(self, state):
        """Restore a state produced by ``__getstate__``.

        The sub-transformers must already exist on ``self`` (i.e. the
        instance was created through ``__init__``), because only their
        states, not their classes, are stored.
        """
        _state = copy.deepcopy(state)
        self.quantile_creator.__setstate__(_state.pop("quantile_creator"))
        self.feature_selector.__setstate__(_state.pop("feature_selector"))
        self.sub_sampler.__setstate__(_state.pop("sub_sampler"))
        self.scaler.__setstate__(_state.pop("scaler"))
        super().__setstate__(_state)

    def fit(self, X: dict[str, np.ndarray], y=None):
        """Fit the enabled transformers and the scaler.

        Args:
            X: Mapping from feature type to a 2D feature matrix.
            y: Ignored; accepted for compatibility with the usual
                ``fit(X, y)`` calling convention.

        Returns:
            The fitted preprocessor itself.
        """
        _X = copy.deepcopy(X)

        if self.use_feat_quant:
            _X = self.quantile_creator.fit_transform(_X)

        if self.use_feat_selec:
            _X = self.feature_selector.fit_transform(_X)

        _X = np.concatenate([_X[descr] for descr in self.descriptors], axis=1)
        self.scaler.fit(_X)
        return self

    def transform(
        self, X: dict[str, np.ndarray], y: np.ndarray | None = None
    ) -> np.ndarray | tuple[np.ndarray, np.ndarray]:
        """Apply the fitted preprocessing steps.

        Args:
            X: Mapping from feature type to a 2D feature matrix.
            y: Optional labels that are passed through the sub-sampler
                together with the features.

        Returns:
            The processed feature matrix, or ``(features, labels)`` when
            ``y`` is given.
        """
        # Shallow copy: the dict is re-bound below, the arrays are not edited here.
        _X = X.copy()
        _y = y.copy() if y is not None else None

        if self.use_feat_quant:
            _X = self.quantile_creator.transform(_X)

        if self.use_feat_selec:
            _X = self.feature_selector.transform(_X)

        _X = np.concatenate([_X[descr] for descr in self.descriptors], axis=1)
        _X = self.scaler.transform(_X)

        if _y is None:
            _X = self.sub_sampler.transform(_X)
            return _X

        _X, _y = self.sub_sampler.transform(_X, _y)
        return _X, _y
307
 
308
 
309
  def create_cleaned_mol_objects(smiles: list[str]) -> tuple[list[Mol], np.ndarray]:
 
482
 
483
def create_descriptors(
    smiles,
    descriptors,
    **ecfp_kwargs,
):
    """Generate molecular descriptors for multiple SMILES strings.

    Each SMILES is processed and sanitized using RDKit. Rows belonging to
    SMILES that could not be sanitized are passed through ``fill`` (NaN
    encoding, per the helper's documented intent), and a boolean mask is
    returned to indicate which inputs were successfully processed.

    Args:
        smiles (list[str]): List of SMILES strings for which to generate descriptors.
        descriptors (list[str]): List of descriptor types to compute.
            Supported values include:
            ['ecfps', 'tox', 'maccs', 'rdkit_descrs'].
        **ecfp_kwargs: Forwarded unchanged to ``create_ecfp_fps``.

    Returns:
        tuple[dict[str, np.ndarray], np.ndarray]:
            - A dictionary mapping descriptor names to their computed arrays,
              in the order given by ``descriptors``.
            - A boolean mask of shape (len(smiles),) indicating which SMILES
              were successfully sanitized and processed.

    Raises:
        KeyError: If ``descriptors`` contains an unsupported descriptor type.
    """
    # Fail fast on unknown names, before any expensive computation.
    supported = ("ecfps", "tox", "maccs", "rdkit_descrs")
    unknown = [descr for descr in descriptors if descr not in supported]
    if unknown:
        raise KeyError(
            f"Unknown descriptor type(s) {unknown}; supported values are {list(supported)}"
        )

    # Create cleaned rdkit mol objects
    mols, clean_mol_mask = create_cleaned_mol_objects(smiles)
    print(f"Cleaned molecules, {(~clean_mol_mask).sum()} could not be sanitized")
    failed_mask = ~clean_mol_mask

    # Create fingerprints and descriptors. Results go into an explicit dict
    # instead of being looked up among local variables by name.
    computed = {}
    if "ecfps" in descriptors:
        ecfps = create_ecfp_fps(mols, **ecfp_kwargs)
        computed["ecfps"] = fill(ecfps, failed_mask)
        print("Created ECFP fingerprints")

    if "tox" in descriptors:
        tox_patterns = get_tox_patterns(TOX_SMARTS_PATH)
        tox = create_tox_features(mols, tox_patterns)
        computed["tox"] = fill(tox, failed_mask)
        print("Created Tox features")

    if "maccs" in descriptors:
        maccs = create_maccs_keys(mols)
        computed["maccs"] = fill(maccs, failed_mask)
        print("Created MACCS keys")

    if "rdkit_descrs" in descriptors:
        rdkit_descrs = create_rdkit_descriptors(mols)
        computed["rdkit_descrs"] = fill(rdkit_descrs, failed_mask)
        print("Created RDKit descriptors")

    # Collect the features in the order requested by the caller.
    features = {descr: computed[descr] for descr in descriptors}

    return features, clean_mol_mask
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
538
 
539
 
540
  def get_tox21_split(token, cvfold=None):
src/utils.py CHANGED
@@ -7,6 +7,9 @@
7
 
8
  import os
9
  import pickle
 
 
 
10
 
11
  from rdkit import Chem
12
  from rdkit.Chem.MolStandardize import rdMolStandardize
@@ -29,7 +32,7 @@ TASKS = [
29
  "SR-p53",
30
  ]
31
 
32
- KNOWN_DESCR = ["ecfps", "rdkit_descr_quantiles", "maccs", "tox"]
33
 
34
  USED_200_DESCR = [
35
  0,
@@ -433,6 +436,63 @@ class Standardizer:
433
  return mol_out, n_tautomers
434
 
435
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
436
  def load_pickle(path: str):
437
  with open(path, "rb") as file:
438
  content = pickle.load(file)
@@ -459,7 +519,7 @@ def normalize_config(config: dict):
459
  for key, val in config.items():
460
  if isinstance(val, dict):
461
  new_config[key] = normalize_config(val)
462
- elif val in mapping:
463
  new_config[key] = mapping[val]
464
  else:
465
  new_config[key] = val
 
7
 
8
  import os
9
  import pickle
10
+ from typing import Any
11
+
12
+ import numpy as np
13
 
14
  from rdkit import Chem
15
  from rdkit.Chem.MolStandardize import rdMolStandardize
 
32
  "SR-p53",
33
  ]
34
 
35
+ KNOWN_DESCR = ["ecfps", "tox", "maccs", "rdkit_descrs"]
36
 
37
  USED_200_DESCR = [
38
  0,
 
436
  return mol_out, n_tautomers
437
 
438
 
439
class FeatureDictMixin:
    """
    Mixin that enables bidirectional handling of dict-based multi-feature inputs.
    Allows selective removal of columns directly from the combined array.

    ``dict_to_array`` concatenates the selected feature blocks column-wise and
    remembers which column belongs to which feature type; ``array_to_dict``
    uses that bookkeeping to split an array of the same column layout back
    into a dict and re-attaches the entries that were not selected.

    Example input:
    {
        "ecfps": np.ndarray,
        "tox": np.ndarray,
    }

    Args:
        feature_keys: Feature types to operate on, in concatenation order.
            ``None`` selects every key of the input dict.
    """

    def __init__(self, feature_keys=None):
        self.feature_keys = feature_keys
        self._curr_keys = None
        self._unused_data = None
        # Keys actually used by the most recent dict_to_array() call.
        self._active_keys = None

    def dict_to_array(self, input: dict[Any, np.ndarray]) -> np.ndarray:
        """Parse dict input and concatenate into a single array.

        Raises:
            TypeError: If ``input`` is not a dict.
            KeyError: If a selected feature type is missing from ``input``.
            ValueError: If a selected feature block is not 2D.
        """
        if not isinstance(input, dict):
            raise TypeError("Input must be a dict {feature_type: np.ndarray, ...}")

        # ``None`` means "use everything" instead of failing on a membership
        # test against None.
        if self.feature_keys is None:
            selected = list(input)
        else:
            selected = list(self.feature_keys)

        # Entries that are not selected are kept aside and re-attached later.
        self._unused_data = {
            key: value for key, value in input.items() if key not in selected
        }

        curr_keys = []
        output = []
        for key in selected:
            if key not in input:
                raise KeyError(
                    f"Feature '{key}' not found in input; available keys: {list(input)}"
                )
            array = input[key]
            if array.ndim != 2:
                raise ValueError(f"Feature '{key}' must be 2D, got shape {array.shape}")

            # One owner label per column so the array can be split again.
            curr_keys.extend([key] * array.shape[1])
            output.append(array)

        self._curr_keys = np.array(curr_keys)
        self._active_keys = selected

        return np.concatenate(output, axis=1)

    def array_to_dict(self, input: np.ndarray) -> dict[Any, np.ndarray]:
        """Reconstruct dict from a concatenated array.

        Raises:
            ValueError: If no column mapping is stored, i.e. ``dict_to_array``
                was not called since the last reconstruction.
        """
        if self._curr_keys is None:
            raise ValueError("No feature mapping stored. Did you call dict_to_array()?")

        keys = getattr(self, "_active_keys", None)
        if keys is None:
            keys = self.feature_keys

        output = {key: input[:, self._curr_keys == key] for key in keys}
        output.update(self._unused_data)

        # The mapping is single-use; drop it together with the stashed data.
        self._curr_keys = None
        self._unused_data = None
        self._active_keys = None
        return output
494
+
495
+
496
  def load_pickle(path: str):
497
  with open(path, "rb") as file:
498
  content = pickle.load(file)
 
519
  for key, val in config.items():
520
  if isinstance(val, dict):
521
  new_config[key] = normalize_config(val)
522
+ elif isinstance(val, (int, float, str)) and val in mapping:
523
  new_config[key] = mapping[val]
524
  else:
525
  new_config[key] = val
train.py CHANGED
@@ -4,6 +4,7 @@ Script for fitting and saving any preprocessing assets, as well as the fitted RF
4
 
5
  import os
6
  import json
 
7
  import random
8
  import logging
9
  import argparse
@@ -12,11 +13,8 @@ import numpy as np
12
  from datetime import datetime
13
 
14
  from src.model import Tox21RFClassifier
15
- from src.utils import (
16
- create_dir,
17
- normalize_config,
18
- USED_200_DESCR,
19
- )
20
 
21
  parser = argparse.ArgumentParser(description="RF Training script for Tox21 dataset")
22
 
@@ -27,7 +25,7 @@ parser.add_argument(
27
  )
28
 
29
 
30
- def main(cfg):
31
  timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
32
 
33
  # setup logger
@@ -39,7 +37,7 @@ def main(cfg):
39
  handlers=[
40
  logging.FileHandler(
41
  os.path.join(
42
- cfg["log_folder"],
43
  f"{script_name}_{timestamp}.log",
44
  )
45
  ),
@@ -47,50 +45,50 @@ def main(cfg):
47
  ],
48
  )
49
 
50
- task_configs = cfg.pop("task_configs")
51
- logger.info(f"Config: {cfg}")
52
- task_configs_repr = "Task configs: \n" + "\n".join(
53
- [str(val) for key, val in task_configs.items()]
54
  )
55
-
56
- logger.info(f"Task configs: \n{task_configs_repr}")
57
 
58
  # seeding
59
- random.seed(cfg["seed"])
60
- np.random.seed(cfg["seed"])
61
-
62
- train_data = np.load(os.path.join(cfg["data_folder"], "tox21_train_cv4.npz"))
63
- train_X = train_data[
64
- "features"
65
- ] # np.concatenate([train_data[descr] for descr in KNOWN_DESCR], axis=1)
66
- train_y = train_data["labels"]
67
-
68
- val_data = np.load(os.path.join(cfg["data_folder"], "tox21_validation_cv4.npz"))
69
- val_X = val_data[
70
- "features"
71
- ] # np.concatenate([val_data[descr] for descr in KNOWN_DESCR], axis=1)
72
- val_y = val_data["labels"]
73
-
74
- data = np.concatenate([train_X, val_X], axis=0)
75
- labels = np.concatenate([train_y, val_y], axis=0)
76
- logger.info(f"Train data shape: {data.shape}")
77
-
78
- if cfg["model_path"]:
79
  logger.info(
80
- f"Fitted RandomForestClassifier will be saved as: {cfg['model_path']}"
81
  )
82
  else:
83
  logger.info("Fitted RandomForestClassifier will NOT be saved.")
84
 
85
- rdkit_descr_idxs = np.arange(data.shape[1] - len(USED_200_DESCR), data.shape[1])
86
- model = Tox21RFClassifier(
87
- seed=cfg["seed"],
88
- task_config=task_configs,
89
- rdkit_desc_idxs=rdkit_descr_idxs,
90
- )
91
- model.fit_preprocessing(
92
- data, min_var=cfg["feature_minvar"], max_corr=cfg["feature_maxcorr"]
 
93
  )
 
94
 
95
  logger.info("Start training.")
96
  for i, task in enumerate(model.tasks):
@@ -98,28 +96,34 @@ def main(cfg):
98
  label_mask = ~np.isnan(task_labels)
99
  logger.info(f"Fit task {task} using {sum(label_mask)} samples")
100
 
101
- task_data = data[label_mask]
102
  task_labels = task_labels[label_mask].astype(int)
103
 
 
104
  model.fit(task, task_data, task_labels)
105
- if cfg["debug"]:
106
  break
107
 
108
  log_text = f"Finished training."
109
  logger.info(log_text)
110
 
111
- if cfg["model_path"]:
112
- model.save_model(cfg["model_path"])
113
- logger.info(f"Save model as: {cfg['model_path']}")
 
 
 
 
 
114
 
115
 
116
  if __name__ == "__main__":
117
  args = parser.parse_args()
118
 
119
  with open(args.config, "r") as f:
120
- cfg = json.load(f)
121
- cfg = normalize_config(cfg)
122
 
123
- create_dir(cfg["log_folder"])
124
 
125
- main(cfg)
 
4
 
5
  import os
6
  import json
7
+ import joblib
8
  import random
9
  import logging
10
  import argparse
 
13
  from datetime import datetime
14
 
15
  from src.model import Tox21RFClassifier
16
+ from src.preprocess import Tox21Preprocessor
17
+ from src.utils import create_dir, normalize_config
 
 
 
18
 
19
  parser = argparse.ArgumentParser(description="RF Training script for Tox21 dataset")
20
 
 
25
  )
26
 
27
 
28
+ def main(config):
29
  timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
30
 
31
  # setup logger
 
37
  handlers=[
38
  logging.FileHandler(
39
  os.path.join(
40
+ config["log_folder"],
41
  f"{script_name}_{timestamp}.log",
42
  )
43
  ),
 
45
  ],
46
  )
47
 
48
+ logger.info(f"Config: {config}")
49
+ model_configs_repr = "Model configs: \n" + "\n".join(
50
+ [str(val) for val in config["model_configs"].values()]
 
51
  )
52
+ logger.info(f"Model configs: \n{model_configs_repr}")
 
53
 
54
  # seeding
55
+ random.seed(config["seed"])
56
+ np.random.seed(config["seed"])
57
+
58
+ train_data = np.load(os.path.join(config["data_folder"], "tox21_train_cv4.npz"))
59
+ val_data = np.load(os.path.join(config["data_folder"], "tox21_validation_cv4.npz"))
60
+
61
+ # filter out unsanitized molecules
62
+ train_is_clean = train_data["clean_mol_mask"]
63
+ val_is_clean = val_data["clean_mol_mask"]
64
+ train_data = {descr: array[train_is_clean] for descr, array in train_data.items()}
65
+ val_data = {descr: array[val_is_clean] for descr, array in val_data.items()}
66
+
67
+ # combine datasets
68
+ data = {
69
+ descr: np.concatenate([train_data[descr], val_data[descr]], axis=0)
70
+ for descr in config["descriptors"]
71
+ }
72
+ labels = np.concatenate([train_data["labels"], val_data["labels"]], axis=0)
73
+
74
+ if config["ckpt_path"]:
75
  logger.info(
76
+ f"Fitted RandomForestClassifier will be saved as: {config['ckpt_path']}"
77
  )
78
  else:
79
  logger.info("Fitted RandomForestClassifier will NOT be saved.")
80
 
81
+ model = Tox21RFClassifier(seed=config["seed"], config=config["model_configs"])
82
+
83
+ # setup processors
84
+ preprocessor = Tox21Preprocessor(
85
+ feature_selection_config=config["feature_selection"],
86
+ feature_quantilization_config=config["feature_quantilization"],
87
+ descriptors=config["descriptors"],
88
+ max_samples=config["max_samples"],
89
+ scaler=config["scaler"],
90
  )
91
+ preprocessor.fit(data)
92
 
93
  logger.info("Start training.")
94
  for i, task in enumerate(model.tasks):
 
96
  label_mask = ~np.isnan(task_labels)
97
  logger.info(f"Fit task {task} using {sum(label_mask)} samples")
98
 
99
+ task_data = {key: val[label_mask] for key, val in data.items()}
100
  task_labels = task_labels[label_mask].astype(int)
101
 
102
+ task_data = preprocessor.transform(task_data)
103
  model.fit(task, task_data, task_labels)
104
+ if config["debug"]:
105
  break
106
 
107
  log_text = f"Finished training."
108
  logger.info(log_text)
109
 
110
+ if config["ckpt_path"]:
111
+ ckpt = {
112
+ "preprocessor": preprocessor.__getstate__(),
113
+ "models": model.get_state(),
114
+ }
115
+ # model.save_model(config["ckpt_path"])
116
+ joblib.dump(ckpt, config["ckpt_path"])
117
+ logger.info(f"Save model as: {config['ckpt_path']}")
118
 
119
 
120
if __name__ == "__main__":
    # Entry point: read the JSON config named on the command line, normalize
    # it, make sure the log folder exists, then run training.
    args = parser.parse_args()

    with open(args.config, "r") as config_file:
        config = normalize_config(json.load(config_file))

    create_dir(config["log_folder"])

    main(config)