Spaces:
Sleeping
Sleeping
Commit
·
1994acc
1
Parent(s):
9b322e1
refactoring of feature preprocessing
Browse files- config/config.json +27 -8
- predict.py +27 -22
- preprocess.py +23 -145
- src/model.py +13 -87
- src/preprocess.py +330 -99
- src/utils.py +62 -2
- train.py +55 -51
config/config.json
CHANGED
|
@@ -1,14 +1,33 @@
|
|
| 1 |
{
|
| 2 |
"seed": 0,
|
| 3 |
-
"
|
| 4 |
-
"
|
| 5 |
-
|
| 6 |
-
"feature_maxcorr": 0.95,
|
| 7 |
-
"model_path": "checkpoints/rf_alltasks.joblib",
|
| 8 |
-
"data_folder": "data/",
|
| 9 |
"log_folder": "logs/",
|
| 10 |
-
|
| 11 |
-
"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 12 |
"NR-AR": {
|
| 13 |
"max_depth": "none",
|
| 14 |
"max_features": "sqrt",
|
|
|
|
| 1 |
{
|
| 2 |
"seed": 0,
|
| 3 |
+
"debug": "false",
|
| 4 |
+
"device": "cpu",
|
| 5 |
+
|
|
|
|
|
|
|
|
|
|
| 6 |
"log_folder": "logs/",
|
| 7 |
+
|
| 8 |
+
"data_folder": "data/",
|
| 9 |
+
"cvfold": 4,
|
| 10 |
+
"ecfp" : {
|
| 11 |
+
"radius": 3,
|
| 12 |
+
"fpsize": 8192
|
| 13 |
+
},
|
| 14 |
+
"descriptors": ["ecfps", "tox", "maccs", "rdkit_descrs"],
|
| 15 |
+
"feature_selection": {
|
| 16 |
+
"use": "true",
|
| 17 |
+
"min_var": 0.01,
|
| 18 |
+
"max_corr": 0.95,
|
| 19 |
+
"feature_keys": ["ecfps", "tox", "maccs", "rdkit_descrs"],
|
| 20 |
+
"max_features": -1
|
| 21 |
+
},
|
| 22 |
+
"feature_quantilization": {
|
| 23 |
+
"use": "true",
|
| 24 |
+
"feature_keys": ["rdkit_descrs"]
|
| 25 |
+
},
|
| 26 |
+
"max_samples": -1,
|
| 27 |
+
"scaler": "standard",
|
| 28 |
+
|
| 29 |
+
"ckpt_path": "checkpoints/rf_alltasks.joblib",
|
| 30 |
+
"model_configs": {
|
| 31 |
"NR-AR": {
|
| 32 |
"max_depth": "none",
|
| 33 |
"max_features": "sqrt",
|
predict.py
CHANGED
|
@@ -6,15 +6,18 @@ SMILES and target names as keys.
|
|
| 6 |
|
| 7 |
# ---------------------------------------------------------------------------------------
|
| 8 |
# Dependencies
|
|
|
|
|
|
|
| 9 |
from collections import defaultdict
|
| 10 |
|
| 11 |
-
import
|
| 12 |
import numpy as np
|
| 13 |
from tqdm import tqdm
|
| 14 |
|
| 15 |
-
from src.preprocess import create_descriptors
|
| 16 |
-
from src.utils import TASKS, normalize_config
|
| 17 |
from src.model import Tox21RFClassifier
|
|
|
|
|
|
|
|
|
|
| 18 |
|
| 19 |
# ---------------------------------------------------------------------------------------
|
| 20 |
CONFIG_FILE = "./config/config.json"
|
|
@@ -35,20 +38,29 @@ def predict(
|
|
| 35 |
print(f"Received {len(smiles_list)} SMILES strings")
|
| 36 |
|
| 37 |
with open(CONFIG_FILE, "r") as f:
|
| 38 |
-
|
| 39 |
-
|
| 40 |
|
| 41 |
features, is_clean = create_descriptors(
|
| 42 |
-
smiles_list,
|
| 43 |
)
|
| 44 |
-
|
| 45 |
-
print(f"Created {n_feats} descriptors for {n_clean_mols} molecules.")
|
| 46 |
print(f"{len(is_clean) - sum(is_clean)} molecules removed during cleaning")
|
| 47 |
|
| 48 |
# setup model
|
| 49 |
model = Tox21RFClassifier()
|
| 50 |
-
|
| 51 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 52 |
|
| 53 |
# make predictions
|
| 54 |
predictions = defaultdict(dict)
|
|
@@ -56,24 +68,17 @@ def predict(
|
|
| 56 |
print(f"Create predictions:")
|
| 57 |
preds = []
|
| 58 |
for target in tqdm(TASKS):
|
| 59 |
-
X =
|
| 60 |
-
|
|
|
|
| 61 |
|
|
|
|
| 62 |
preds[~is_clean] = default_prediction
|
| 63 |
preds[is_clean] = model.predict(target, X)
|
| 64 |
|
| 65 |
for smiles, pred in zip(smiles_list, preds):
|
| 66 |
predictions[smiles][target] = float(pred)
|
| 67 |
-
if
|
| 68 |
break
|
| 69 |
|
| 70 |
return predictions
|
| 71 |
-
|
| 72 |
-
|
| 73 |
-
# from hiddens.testing import test_eval
|
| 74 |
-
|
| 75 |
-
# with open(CONFIG_FILE, "r") as f:
|
| 76 |
-
# config = json.load(f)
|
| 77 |
-
# config = normalize_config(config)
|
| 78 |
-
|
| 79 |
-
# test_eval(predict, debug=config["debug"], use_only_clean=False, use_only_first=False)
|
|
|
|
| 6 |
|
| 7 |
# ---------------------------------------------------------------------------------------
|
| 8 |
# Dependencies
|
| 9 |
+
import json
|
| 10 |
+
import copy
|
| 11 |
from collections import defaultdict
|
| 12 |
|
| 13 |
+
import joblib
|
| 14 |
import numpy as np
|
| 15 |
from tqdm import tqdm
|
| 16 |
|
|
|
|
|
|
|
| 17 |
from src.model import Tox21RFClassifier
|
| 18 |
+
from src.preprocess import create_descriptors, FeaturePreprocessor
|
| 19 |
+
from src.utils import TASKS, normalize_config
|
| 20 |
+
|
| 21 |
|
| 22 |
# ---------------------------------------------------------------------------------------
|
| 23 |
CONFIG_FILE = "./config/config.json"
|
|
|
|
| 38 |
print(f"Received {len(smiles_list)} SMILES strings")
|
| 39 |
|
| 40 |
with open(CONFIG_FILE, "r") as f:
|
| 41 |
+
config = json.load(f)
|
| 42 |
+
config = normalize_config(config)
|
| 43 |
|
| 44 |
features, is_clean = create_descriptors(
|
| 45 |
+
smiles_list, config["descriptors"], **config["ecfp"]
|
| 46 |
)
|
| 47 |
+
print(f"Created descriptors for {sum(is_clean)} molecules.")
|
|
|
|
| 48 |
print(f"{len(is_clean) - sum(is_clean)} molecules removed during cleaning")
|
| 49 |
|
| 50 |
# setup model
|
| 51 |
model = Tox21RFClassifier()
|
| 52 |
+
preprocessor = FeaturePreprocessor(
|
| 53 |
+
feature_selection_config=config["feature_selection"],
|
| 54 |
+
feature_quantilization_config=config["feature_quantilization"],
|
| 55 |
+
descriptors=config["descriptors"],
|
| 56 |
+
max_samples=config["max_samples"],
|
| 57 |
+
scaler=config["scaler"],
|
| 58 |
+
)
|
| 59 |
+
|
| 60 |
+
ckpt = joblib.load(config["ckpt_path"])
|
| 61 |
+
model.set_state(ckpt["models"])
|
| 62 |
+
preprocessor.__setstate__(ckpt["preprocessor"])
|
| 63 |
+
print(f"Loaded model & preprocessor from {config['ckpt_path']}")
|
| 64 |
|
| 65 |
# make predictions
|
| 66 |
predictions = defaultdict(dict)
|
|
|
|
| 68 |
print(f"Create predictions:")
|
| 69 |
preds = []
|
| 70 |
for target in tqdm(TASKS):
|
| 71 |
+
X = copy.deepcopy(features)
|
| 72 |
+
X = {descr: array[is_clean] for descr, array in X.items()}
|
| 73 |
+
X = preprocessor.transform(X)
|
| 74 |
|
| 75 |
+
preds = np.empty_like(is_clean, dtype=np.float64)
|
| 76 |
preds[~is_clean] = default_prediction
|
| 77 |
preds[is_clean] = model.predict(target, X)
|
| 78 |
|
| 79 |
for smiles, pred in zip(smiles_list, preds):
|
| 80 |
predictions[smiles][target] = float(pred)
|
| 81 |
+
if config["debug"]:
|
| 82 |
break
|
| 83 |
|
| 84 |
return predictions
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
preprocess.py
CHANGED
|
@@ -7,186 +7,64 @@ SMILES and target names as keys.
|
|
| 7 |
"""
|
| 8 |
|
| 9 |
import os
|
|
|
|
| 10 |
import argparse
|
| 11 |
|
| 12 |
import numpy as np
|
| 13 |
|
| 14 |
from src.preprocess import create_descriptors, get_tox21_split
|
| 15 |
-
from src.utils import
|
| 16 |
-
TASKS,
|
| 17 |
-
HF_TOKEN,
|
| 18 |
-
create_dir,
|
| 19 |
-
)
|
| 20 |
|
| 21 |
parser = argparse.ArgumentParser(
|
| 22 |
description="Data preprocessing script for the Tox21 dataset"
|
| 23 |
)
|
| 24 |
|
| 25 |
parser.add_argument(
|
| 26 |
-
"--
|
| 27 |
-
type=str,
|
| 28 |
-
default="data/",
|
| 29 |
-
help="Folder to which preprocessed the data CSV and NPZ files should be saved.",
|
| 30 |
-
)
|
| 31 |
-
|
| 32 |
-
parser.add_argument(
|
| 33 |
-
"--cv_fold",
|
| 34 |
-
type=int,
|
| 35 |
-
default=4,
|
| 36 |
-
help="Select fold used as validation set.",
|
| 37 |
-
)
|
| 38 |
-
|
| 39 |
-
parser.add_argument(
|
| 40 |
-
"--feature_selection",
|
| 41 |
-
type=int,
|
| 42 |
-
default=1,
|
| 43 |
-
help="True (=1) to use feature selection.",
|
| 44 |
-
)
|
| 45 |
-
|
| 46 |
-
parser.add_argument(
|
| 47 |
-
"--feature_selection_path",
|
| 48 |
-
type=str,
|
| 49 |
-
default="feat_selection.npz",
|
| 50 |
-
help="Filename for saving feature selections.",
|
| 51 |
-
)
|
| 52 |
-
|
| 53 |
-
parser.add_argument(
|
| 54 |
-
"--min_var",
|
| 55 |
-
type=float,
|
| 56 |
-
default=0.01,
|
| 57 |
-
help="Minimum variance threshold for selecting features.",
|
| 58 |
-
)
|
| 59 |
-
|
| 60 |
-
parser.add_argument(
|
| 61 |
-
"--max_corr",
|
| 62 |
-
type=float,
|
| 63 |
-
default=0.95,
|
| 64 |
-
help="Maximum correlation threshold for selecting features.",
|
| 65 |
-
)
|
| 66 |
-
|
| 67 |
-
parser.add_argument(
|
| 68 |
-
"--ecdfs_path",
|
| 69 |
type=str,
|
| 70 |
-
default="
|
| 71 |
-
help="Filename to save ECDFs.",
|
| 72 |
)
|
| 73 |
|
| 74 |
-
parser.add_argument(
|
| 75 |
-
"--ecfps_radius",
|
| 76 |
-
type=int,
|
| 77 |
-
default=3,
|
| 78 |
-
help="Radius used for creating ECFPs.",
|
| 79 |
-
)
|
| 80 |
|
| 81 |
-
|
| 82 |
-
"
|
| 83 |
-
|
| 84 |
-
default=8192,
|
| 85 |
-
help="Folds used for creating ECFPs.",
|
| 86 |
-
)
|
| 87 |
|
| 88 |
-
|
| 89 |
-
"--ecdfs",
|
| 90 |
-
type=int,
|
| 91 |
-
default=1,
|
| 92 |
-
help="True (=1) to use ECDFs for creating quantiles of the RDKit descriptors.",
|
| 93 |
-
)
|
| 94 |
-
|
| 95 |
-
|
| 96 |
-
def main(args):
|
| 97 |
-
"""Preprocessing train/val data to use for TabPFN.
|
| 98 |
-
|
| 99 |
-
1. Download Tox21 train/val data from HF
|
| 100 |
-
2. Preprocess dataset splits
|
| 101 |
-
"""
|
| 102 |
-
ds = get_tox21_split(HF_TOKEN, cvfold=args.cv_fold)
|
| 103 |
-
|
| 104 |
-
feature_creation_kwargs = {
|
| 105 |
-
"radius": args.ecfps_radius,
|
| 106 |
-
"fpsize": args.ecfps_folds,
|
| 107 |
-
"min_var": args.min_var,
|
| 108 |
-
"max_corr": args.max_corr,
|
| 109 |
-
}
|
| 110 |
-
removed_mols = 0
|
| 111 |
-
|
| 112 |
-
splits = ["train", "validation", "test"]
|
| 113 |
for split in splits:
|
| 114 |
|
| 115 |
print(f"Preprocess {split} molecules")
|
| 116 |
|
| 117 |
-
|
| 118 |
-
|
| 119 |
-
|
| 120 |
-
|
| 121 |
-
|
| 122 |
-
|
| 123 |
-
ds_split = pd.read_csv("data/tox21_test_cv4.csv")
|
| 124 |
-
|
| 125 |
-
smiles = ds_split["smiles"]
|
| 126 |
-
|
| 127 |
-
features, clean_mol_mask = create_descriptors(smiles, **feature_creation_kwargs)
|
| 128 |
-
|
| 129 |
-
# if split == "train":
|
| 130 |
-
# output = create_descriptors(
|
| 131 |
-
# smiles,
|
| 132 |
-
# return_feature_selection=True,
|
| 133 |
-
# return_ecdfs=True,
|
| 134 |
-
# **feature_creation_kwargs,
|
| 135 |
-
# )
|
| 136 |
-
# features = output.pop("features")
|
| 137 |
-
|
| 138 |
-
# if args.feature_selection:
|
| 139 |
-
# feature_selection = output.pop("feature_selection")
|
| 140 |
-
# np.savez(
|
| 141 |
-
# args.feature_selection_path,
|
| 142 |
-
# ecfps_selec=feature_selection["ecfps_selec"],
|
| 143 |
-
# tox_selec=feature_selection["tox_selec"],
|
| 144 |
-
# )
|
| 145 |
-
|
| 146 |
-
# print(f"Saved feature selection under {args.feature_selection_path}")
|
| 147 |
-
|
| 148 |
-
# if args.ecdfs:
|
| 149 |
-
# ecdfs = output.pop("ecdfs")
|
| 150 |
-
# write_pickle(args.ecdfs_path, ecdfs)
|
| 151 |
-
# print(f"Saved ECDFs under {args.ecdfs_path}")
|
| 152 |
-
|
| 153 |
-
# else:
|
| 154 |
-
# features = create_descriptors(
|
| 155 |
-
# smiles,
|
| 156 |
-
# ecdfs=ecdfs,
|
| 157 |
-
# feature_selection=feature_selection,
|
| 158 |
-
# **feature_creation_kwargs,
|
| 159 |
-
# )["features"]
|
| 160 |
-
removed_mols += (~clean_mol_mask).sum()
|
| 161 |
|
| 162 |
labels = []
|
| 163 |
for task in TASKS:
|
| 164 |
labels.append(ds_split[task].to_numpy())
|
| 165 |
labels = np.stack(labels, axis=1)
|
| 166 |
|
| 167 |
-
save_path = os.path.join(
|
| 168 |
with open(save_path, "wb") as f:
|
| 169 |
np.savez(
|
| 170 |
f,
|
| 171 |
-
|
| 172 |
-
|
| 173 |
-
|
| 174 |
)
|
| 175 |
print(f"Saved preprocessed {split} split under {save_path}")
|
| 176 |
-
print(f"{removed_mols} mols were removed during cleaning across all datasets")
|
| 177 |
print("Preprocessing finished successfully")
|
| 178 |
|
| 179 |
|
| 180 |
if __name__ == "__main__":
|
| 181 |
args = parser.parse_args()
|
| 182 |
|
| 183 |
-
|
| 184 |
-
|
| 185 |
-
|
| 186 |
-
# )
|
| 187 |
-
|
| 188 |
-
create_dir(args.save_folder)
|
| 189 |
-
# create_dir(args.ecdfs_path, is_file=True)
|
| 190 |
-
# create_dir(args.feature_selection_path, is_file=True)
|
| 191 |
|
| 192 |
-
|
|
|
|
|
|
| 7 |
"""
|
| 8 |
|
| 9 |
import os
|
| 10 |
+
import json
|
| 11 |
import argparse
|
| 12 |
|
| 13 |
import numpy as np
|
| 14 |
|
| 15 |
from src.preprocess import create_descriptors, get_tox21_split
|
| 16 |
+
from src.utils import TASKS, HF_TOKEN, create_dir, normalize_config
|
|
|
|
|
|
|
|
|
|
|
|
|
| 17 |
|
| 18 |
parser = argparse.ArgumentParser(
|
| 19 |
description="Data preprocessing script for the Tox21 dataset"
|
| 20 |
)
|
| 21 |
|
| 22 |
parser.add_argument(
|
| 23 |
+
"--config",
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 24 |
type=str,
|
| 25 |
+
default="config/config.json",
|
|
|
|
| 26 |
)
|
| 27 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 28 |
|
| 29 |
+
def main(config):
    """Create molecule descriptors for the HF Tox21 dataset.

    For each of the "train" and "validation" splits this computes the
    configured descriptors from the SMILES column, stacks the per-task label
    columns, and writes everything to one ``.npz`` file inside
    ``config["data_folder"]``.

    Args:
        config (dict): Configuration dictionary. The keys read here are
            ``"cvfold"``, ``"descriptors"``, ``"ecfp"`` and ``"data_folder"``.
    """
    cvfold = config["cvfold"]
    ds = get_tox21_split(HF_TOKEN, cvfold=cvfold)

    splits = ["train", "validation"]
    for split in splits:

        print(f"Preprocess {split} molecules")

        ds_split = ds[split]
        smiles = list(ds_split["smiles"])

        features, clean_mol_mask = create_descriptors(
            smiles, config["descriptors"], **config["ecfp"]
        )

        # One label column per task, in TASKS order -> (n_molecules, n_tasks).
        labels = np.stack([ds_split[task].to_numpy() for task in TASKS], axis=1)

        # Name the file after the fold that was actually used instead of a
        # hard-coded "cv4"; with the default cvfold=4 the name is unchanged.
        # NOTE(review): confirm that readers of these files (e.g. the training
        # script) derive the file name from config["cvfold"] as well.
        save_path = os.path.join(
            config["data_folder"], f"tox21_{split}_cv{cvfold}.npz"
        )
        with open(save_path, "wb") as f:
            np.savez(
                f,
                clean_mol_mask=clean_mol_mask,
                labels=labels,
                **features,
            )
        print(f"Saved preprocessed {split} split under {save_path}")
    print("Preprocessing finished successfully")
|
| 60 |
|
| 61 |
|
| 62 |
if __name__ == "__main__":
|
| 63 |
args = parser.parse_args()
|
| 64 |
|
| 65 |
+
with open(args.config, "r") as f:
|
| 66 |
+
config = json.load(f)
|
| 67 |
+
config = normalize_config(config)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 68 |
|
| 69 |
+
create_dir(config["data_folder"])
|
| 70 |
+
main(config)
|
src/model.py
CHANGED
|
@@ -6,15 +6,9 @@ SMILES and target names as keys.
|
|
| 6 |
|
| 7 |
# ---------------------------------------------------------------------------------------
|
| 8 |
# Dependencies
|
| 9 |
-
import os
|
| 10 |
-
import joblib
|
| 11 |
-
|
| 12 |
import numpy as np
|
| 13 |
-
|
| 14 |
from sklearn.ensemble import RandomForestClassifier
|
| 15 |
-
from sklearn.preprocessing import StandardScaler
|
| 16 |
|
| 17 |
-
from .preprocess import get_feature_selection, get_ecdfs, create_quantiles
|
| 18 |
from .utils import TASKS
|
| 19 |
|
| 20 |
|
|
@@ -22,100 +16,34 @@ from .utils import TASKS
|
|
| 22 |
class Tox21RFClassifier:
|
| 23 |
"""A random forest classifier that assigns a toxicity score to a given SMILES string."""
|
| 24 |
|
| 25 |
-
def __init__(
|
| 26 |
-
self, seed: int = 42, task_config: dict = None, rdkit_desc_idxs: list[int] = []
|
| 27 |
-
):
|
| 28 |
"""Initialize a random forest classifier for each of the 12 Tox21 tasks.
|
| 29 |
|
| 30 |
Args:
|
| 31 |
seed (int, optional): seed for RF to ensure reproducibility. Defaults to 42.
|
| 32 |
"""
|
| 33 |
self.tasks = TASKS
|
| 34 |
-
self.rdkit_desc_idxs = rdkit_desc_idxs
|
| 35 |
|
| 36 |
self.models = {
|
| 37 |
task: RandomForestClassifier(
|
| 38 |
random_state=seed,
|
| 39 |
n_jobs=8,
|
| 40 |
-
**(
|
| 41 |
-
{"n_estimators": 1000} if task_config is None else task_config[task]
|
| 42 |
-
),
|
| 43 |
)
|
| 44 |
for task in self.tasks
|
| 45 |
}
|
| 46 |
-
self.feature_selection = None
|
| 47 |
-
self.ecdfs = None
|
| 48 |
-
self.scaler = StandardScaler()
|
| 49 |
-
|
| 50 |
-
def load_model(self, path: str) -> None:
|
| 51 |
-
"""Loads the model from a given path
|
| 52 |
-
|
| 53 |
-
Args:
|
| 54 |
-
path (str): path to model checkpoint
|
| 55 |
-
"""
|
| 56 |
-
model = joblib.load(path)
|
| 57 |
-
|
| 58 |
-
self.models = model["models"]
|
| 59 |
-
self.scaler = model["scalers"]
|
| 60 |
-
self.rdkit_desc_idxs = model["rdkit_desc_idxs"]
|
| 61 |
-
|
| 62 |
-
self.feature_selection = model["feature_selections"]
|
| 63 |
-
self.ecdfs = model["ecdfs"]
|
| 64 |
|
| 65 |
-
def
|
| 66 |
-
"""
|
| 67 |
|
| 68 |
Args:
|
| 69 |
-
|
| 70 |
"""
|
| 71 |
-
|
| 72 |
-
os.makedirs(os.path.dirname(path))
|
| 73 |
-
|
| 74 |
-
model = {
|
| 75 |
-
"models": self.models,
|
| 76 |
-
"feature_selections": self.feature_selection,
|
| 77 |
-
"ecdfs": self.ecdfs,
|
| 78 |
-
"scalers": self.scaler,
|
| 79 |
-
"rdkit_desc_idxs": self.rdkit_desc_idxs,
|
| 80 |
-
}
|
| 81 |
|
| 82 |
-
|
| 83 |
-
|
| 84 |
-
|
| 85 |
-
X_ = X.copy()
|
| 86 |
-
|
| 87 |
-
_, n_feat = X.shape
|
| 88 |
-
|
| 89 |
-
if self.rdkit_desc_idxs is None:
|
| 90 |
-
self.rdkit_desc_idxs = np.arange(n_feat)
|
| 91 |
-
else:
|
| 92 |
-
assert (
|
| 93 |
-
self.rdkit_desc_idxs < n_feat
|
| 94 |
-
).all(), "passed to_adapt list contains more features than in X!"
|
| 95 |
-
|
| 96 |
-
self.ecdfs = get_ecdfs(X_[:, self.rdkit_desc_idxs])
|
| 97 |
-
X_[:, self.rdkit_desc_idxs] = create_quantiles(
|
| 98 |
-
X_[:, self.rdkit_desc_idxs], self.ecdfs
|
| 99 |
-
)
|
| 100 |
-
|
| 101 |
-
# get feature selection
|
| 102 |
-
self.feature_selection = get_feature_selection(
|
| 103 |
-
X_, min_var=min_var, max_corr=max_corr
|
| 104 |
-
)
|
| 105 |
-
X_ = X_[:, self.feature_selection]
|
| 106 |
-
|
| 107 |
-
# fit scaler
|
| 108 |
-
X_ = self.scaler.fit(X_)
|
| 109 |
-
|
| 110 |
-
def _preprocess(self, X: np.ndarray) -> None:
|
| 111 |
-
X_ = X.copy()
|
| 112 |
-
|
| 113 |
-
X_[:, self.rdkit_desc_idxs] = create_quantiles(
|
| 114 |
-
X_[:, self.rdkit_desc_idxs], self.ecdfs
|
| 115 |
-
)
|
| 116 |
-
X_ = X_[:, self.feature_selection]
|
| 117 |
-
X_ = self.scaler.transform(X_)
|
| 118 |
-
return X_
|
| 119 |
|
| 120 |
def fit(self, task: str, X: np.ndarray, y: np.ndarray) -> None:
|
| 121 |
"""Train the random forest for a given task
|
|
@@ -126,9 +54,8 @@ class Tox21RFClassifier:
|
|
| 126 |
y (np.ndarray): training labels
|
| 127 |
"""
|
| 128 |
assert task in self.tasks, f"Unknown task: {task}"
|
| 129 |
-
|
| 130 |
-
|
| 131 |
-
self.models[task].fit(X_, y)
|
| 132 |
|
| 133 |
def predict(self, task: str, X: np.ndarray) -> np.ndarray:
|
| 134 |
"""Predicts labels for a given Tox21 target using molecule features
|
|
@@ -144,6 +71,5 @@ class Tox21RFClassifier:
|
|
| 144 |
assert (
|
| 145 |
len(X.shape) == 2
|
| 146 |
), f"Function expects 2D np.array. Current shape: {X.shape}"
|
| 147 |
-
|
| 148 |
-
|
| 149 |
-
return self.models[task].predict_proba(X_)[:, 1]
|
|
|
|
| 6 |
|
| 7 |
# ---------------------------------------------------------------------------------------
|
| 8 |
# Dependencies
|
|
|
|
|
|
|
|
|
|
| 9 |
import numpy as np
|
|
|
|
| 10 |
from sklearn.ensemble import RandomForestClassifier
|
|
|
|
| 11 |
|
|
|
|
| 12 |
from .utils import TASKS
|
| 13 |
|
| 14 |
|
|
|
|
| 16 |
class Tox21RFClassifier:
|
| 17 |
"""A random forest classifier that assigns a toxicity score to a given SMILES string."""
|
| 18 |
|
| 19 |
+
def __init__(self, seed: int = 42, config: dict = None):
|
|
|
|
|
|
|
| 20 |
"""Initialize a random forest classifier for each of the 12 Tox21 tasks.
|
| 21 |
|
| 22 |
Args:
|
| 23 |
seed (int, optional): seed for RF to ensure reproducibility. Defaults to 42.
|
| 24 |
"""
|
| 25 |
self.tasks = TASKS
|
|
|
|
| 26 |
|
| 27 |
self.models = {
|
| 28 |
task: RandomForestClassifier(
|
| 29 |
random_state=seed,
|
| 30 |
n_jobs=8,
|
| 31 |
+
**({"n_estimators": 1000} if config is None else config[task]),
|
|
|
|
|
|
|
| 32 |
)
|
| 33 |
for task in self.tasks
|
| 34 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 35 |
|
| 36 |
+
def set_state(self, state: dict) -> None:
|
| 37 |
+
"""Sets the state of the model
|
| 38 |
|
| 39 |
Args:
|
| 40 |
+
state (dict): models state dict
|
| 41 |
"""
|
| 42 |
+
self.models = state
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 43 |
|
| 44 |
+
def get_state(self) -> dict:
    """Return the model state dict.

    Returns:
        dict: ``{"models": self.models}``. Note that ``set_state`` assigns
        its argument directly to ``self.models``, so pass it the value stored
        under the ``"models"`` key rather than this wrapper dict.
    """
    return {"models": self.models}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 47 |
|
| 48 |
def fit(self, task: str, X: np.ndarray, y: np.ndarray) -> None:
|
| 49 |
"""Train the random forest for a given task
|
|
|
|
| 54 |
y (np.ndarray): training labels
|
| 55 |
"""
|
| 56 |
assert task in self.tasks, f"Unknown task: {task}"
|
| 57 |
+
_X, _y = X.copy(), y.copy()
|
| 58 |
+
self.models[task].fit(_X, _y)
|
|
|
|
| 59 |
|
| 60 |
def predict(self, task: str, X: np.ndarray) -> np.ndarray:
|
| 61 |
"""Predicts labels for a given Tox21 target using molecule features
|
|
|
|
| 71 |
assert (
|
| 72 |
len(X.shape) == 2
|
| 73 |
), f"Function expects 2D np.array. Current shape: {X.shape}"
|
| 74 |
+
_X = X.copy()
|
| 75 |
+
return self.models[task].predict_proba(_X)[:, 1]
|
|
|
src/preprocess.py
CHANGED
|
@@ -6,20 +6,304 @@ As an input it takes a list of SMILES and it outputs a nested dictionary with
|
|
| 6 |
SMILES and target names as keys.
|
| 7 |
"""
|
| 8 |
|
|
|
|
| 9 |
import json
|
|
|
|
| 10 |
|
| 11 |
import numpy as np
|
| 12 |
import pandas as pd
|
| 13 |
|
| 14 |
from datasets import load_dataset
|
|
|
|
| 15 |
from sklearn.feature_selection import VarianceThreshold
|
|
|
|
| 16 |
from statsmodels.distributions.empirical_distribution import ECDF
|
| 17 |
|
| 18 |
from rdkit import Chem, DataStructs
|
| 19 |
from rdkit.Chem import Descriptors, rdFingerprintGenerator, MACCSkeys
|
| 20 |
from rdkit.Chem.rdchem import Mol
|
| 21 |
|
| 22 |
-
from .utils import USED_200_DESCR, TOX_SMARTS_PATH, Standardizer
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 23 |
|
| 24 |
|
| 25 |
def create_cleaned_mol_objects(smiles: list[str]) -> tuple[list[Mol], np.ndarray]:
|
|
@@ -198,112 +482,59 @@ def fill(features, mask, value=np.nan):
|
|
| 198 |
|
| 199 |
def create_descriptors(
|
| 200 |
smiles,
|
| 201 |
-
|
| 202 |
-
|
| 203 |
-
return_ecdfs=False,
|
| 204 |
-
return_feature_selection=False,
|
| 205 |
-
**kwargs,
|
| 206 |
):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 207 |
# Create cleaned rdkit mol objects
|
| 208 |
mols, clean_mol_mask = create_cleaned_mol_objects(smiles)
|
| 209 |
-
print("Cleaned molecules")
|
| 210 |
-
|
| 211 |
-
tox_patterns = get_tox_patterns(TOX_SMARTS_PATH)
|
| 212 |
|
| 213 |
# Create fingerprints and descriptors
|
| 214 |
-
ecfps
|
| 215 |
-
|
| 216 |
-
|
| 217 |
-
|
| 218 |
-
|
| 219 |
-
|
| 220 |
-
|
| 221 |
-
|
| 222 |
-
|
| 223 |
-
|
| 224 |
-
|
| 225 |
-
|
| 226 |
-
|
| 227 |
-
|
| 228 |
-
|
| 229 |
-
|
| 230 |
-
|
| 231 |
-
|
| 232 |
-
|
| 233 |
-
|
| 234 |
-
# tox_selec = feature_selection["tox_selec"]
|
| 235 |
-
|
| 236 |
-
# ecfps = ecfps[:, ecfps_selec]
|
| 237 |
-
# tox = tox[:, tox_selec]
|
| 238 |
-
|
| 239 |
-
maccs = create_maccs_keys(mols)
|
| 240 |
-
# maccs = fill(maccs, ~clean_mol_mask)
|
| 241 |
-
print("Created MACCS keys")
|
| 242 |
-
|
| 243 |
-
rdkit_descrs = create_rdkit_descriptors(mols)
|
| 244 |
-
# rdkit_descrs = fill(rdkit_descrs, ~clean_mol_mask)
|
| 245 |
-
print("Created RDKit descriptors")
|
| 246 |
-
|
| 247 |
-
# # Create and save ecdfs
|
| 248 |
-
# if ecdfs is None:
|
| 249 |
-
# print("Create ECDFs")
|
| 250 |
-
# ecdfs = []
|
| 251 |
-
# for column in range(rdkit_descrs.shape[1]):
|
| 252 |
-
# raw_values = rdkit_descrs[:, column].reshape(-1)
|
| 253 |
-
# ecdfs.append(ECDF(raw_values))
|
| 254 |
-
|
| 255 |
-
# # Create quantiles
|
| 256 |
-
# rdkit_descr_quantiles = create_quantiles(rdkit_descrs, ecdfs)
|
| 257 |
-
# # expand using mol_mask
|
| 258 |
-
# rdkit_descr_quantiles = fill(rdkit_descr_quantiles, ~clean_mol_mask)
|
| 259 |
-
# print("Created quantiles of RDKit descriptors")
|
| 260 |
|
| 261 |
# concatenate features
|
| 262 |
-
|
| 263 |
-
|
| 264 |
-
|
| 265 |
-
# "maccs": maccs,
|
| 266 |
-
# "rdkit_descr_quantiles": rdkit_descr_quantiles,
|
| 267 |
-
# }
|
| 268 |
-
# for feat in [ecfps, tox, maccs, rdkit_descrs]:
|
| 269 |
-
# print(feat.shape)
|
| 270 |
-
features = np.concat((ecfps, tox, maccs, rdkit_descrs), axis=1)
|
| 271 |
-
# return_dict = {"features": features}
|
| 272 |
-
# if return_ecdfs:
|
| 273 |
-
# return_dict["ecdfs"] = ecdfs
|
| 274 |
-
# if return_feature_selection:
|
| 275 |
-
# return_dict["feature_selection"] = feature_selection
|
| 276 |
-
return features, clean_mol_mask
|
| 277 |
|
| 278 |
-
|
| 279 |
-
def get_ecdfs(raw_features: np.ndarray, **kwargs) -> np.ndarray:
|
| 280 |
-
ecdfs = []
|
| 281 |
-
for column in range(raw_features.shape[1]):
|
| 282 |
-
raw_values = raw_features[:, column].reshape(-1)
|
| 283 |
-
ecdfs.append(ECDF(raw_values))
|
| 284 |
-
return ecdfs
|
| 285 |
-
|
| 286 |
-
|
| 287 |
-
def get_feature_selection(
|
| 288 |
-
raw_features: np.ndarray, min_var=0.01, max_corr=0.95, **kwargs
|
| 289 |
-
) -> np.ndarray:
|
| 290 |
-
# select features with at least min_var variation
|
| 291 |
-
var_thresh = VarianceThreshold(threshold=min_var)
|
| 292 |
-
feature_selection = var_thresh.fit(raw_features).get_support(indices=True)
|
| 293 |
-
|
| 294 |
-
n_features_preselected = len(feature_selection)
|
| 295 |
-
|
| 296 |
-
# Remove highly correlated features
|
| 297 |
-
corr_matrix = np.corrcoef(raw_features[:, feature_selection], rowvar=False)
|
| 298 |
-
upper_tri = np.triu(corr_matrix, k=1)
|
| 299 |
-
to_keep = np.ones((n_features_preselected,), dtype=bool)
|
| 300 |
-
for i in range(upper_tri.shape[0]):
|
| 301 |
-
for j in range(upper_tri.shape[1]):
|
| 302 |
-
if upper_tri[i, j] > max_corr:
|
| 303 |
-
to_keep[j] = False
|
| 304 |
-
|
| 305 |
-
feature_selection = feature_selection[to_keep]
|
| 306 |
-
return feature_selection
|
| 307 |
|
| 308 |
|
| 309 |
def get_tox21_split(token, cvfold=None):
|
|
|
|
| 6 |
SMILES and target names as keys.
|
| 7 |
"""
|
| 8 |
|
| 9 |
+
import copy
|
| 10 |
import json
|
| 11 |
+
from typing import Any
|
| 12 |
|
| 13 |
import numpy as np
|
| 14 |
import pandas as pd
|
| 15 |
|
| 16 |
from datasets import load_dataset
|
| 17 |
+
from sklearn.base import BaseEstimator, TransformerMixin
|
| 18 |
from sklearn.feature_selection import VarianceThreshold
|
| 19 |
+
from sklearn.preprocessing import StandardScaler, FunctionTransformer
|
| 20 |
from statsmodels.distributions.empirical_distribution import ECDF
|
| 21 |
|
| 22 |
from rdkit import Chem, DataStructs
|
| 23 |
from rdkit.Chem import Descriptors, rdFingerprintGenerator, MACCSkeys
|
| 24 |
from rdkit.Chem.rdchem import Mol
|
| 25 |
|
| 26 |
+
from .utils import USED_200_DESCR, TOX_SMARTS_PATH, Standardizer, FeatureDictMixin
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
class SquashScaler(TransformerMixin, BaseEstimator):
    """Scaler that standardizes, squashes with tanh, and standardizes again.

    Inspired by DeepTox (Mayr et al., 2016).

    The two standardization stages are stored as ``scaler1`` (applied to the
    raw input) and ``scaler2`` (applied to the tanh-squashed output of
    ``scaler1``).
    """

    def __init__(self):
        self.scaler1 = StandardScaler()
        self.scaler2 = StandardScaler()

    def fit(self, X, y=None):
        """Fit both standardization stages on ``X``.

        Args:
            X: 2D array of shape (n_samples, n_features).
            y: Ignored. Accepted so the scaler follows the scikit-learn
                ``fit(X, y=None)`` convention and can be used through
                ``fit_transform(X, y)`` or inside a pipeline.

        Returns:
            SquashScaler: ``self``.
        """
        # StandardScaler copies its input by default, so X is not modified.
        squashed = np.tanh(self.scaler1.fit_transform(X))
        self.scaler2.fit(squashed)
        return self

    def transform(self, X):
        """Apply standardization, tanh, and re-standardization to ``X``.

        Args:
            X: 2D array with the same number of features as seen in ``fit``.

        Returns:
            np.ndarray: Transformed array with the same shape as ``X``.
        """
        squashed = np.tanh(self.scaler1.transform(X))
        return self.scaler2.transform(squashed)
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
SCALER_REGISTRY = {
|
| 54 |
+
"none": FunctionTransformer,
|
| 55 |
+
"standard": StandardScaler,
|
| 56 |
+
"squash": SquashScaler,
|
| 57 |
+
}
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
class SubSampler(TransformerMixin, BaseEstimator):
    """
    Preprocessor that randomly samples at most `max_samples` rows from data.

    Args:
        max_samples (int): Maximum allowed samples. If -1 (or any value <= 0),
            all samples are retained.

    Input:
        np.ndarray: A 2D NumPy array of shape (n_samples, n_features).

    Output:
        np.ndarray: Subsampled array of shape (min(n_samples, max_samples), n_features).
    """

    def __init__(self, *, max_samples=-1):
        self.max_samples = max_samples
        # Stateless transformer: nothing is learned in `fit`.
        self.is_fitted_ = True

    def fit(self, X: np.ndarray, y: np.ndarray | None = None):
        """No-op; present for scikit-learn API compatibility."""
        return self

    def transform(
        self, X: np.ndarray, y: np.ndarray | None = None
    ) -> np.ndarray | tuple[np.ndarray, np.ndarray]:
        """Return a random subset of the rows of `X` (and of `y`, if given).

        Rows are drawn without replacement, and only when the input has more
        than `max_samples` rows, so the output never contains duplicated rows
        and never has more rows than the input. Sampling uses NumPy's global
        random state.

        Args:
            X: Array whose first axis indexes samples.
            y: Optional array aligned with `X` along the first axis.

        Returns:
            A copy of the (sub)sampled `X`, or a tuple `(X, y)` when `y` is given.
        """
        _X = X.copy()
        _y = None if y is None else y.copy()

        n_samples = _X.shape[0]
        if 0 < self.max_samples < n_samples:
            keep_idxs = np.random.choice(
                n_samples, size=self.max_samples, replace=False
            )
            _X = _X[keep_idxs]
            if _y is not None:
                _y = _y[keep_idxs]

        if _y is None:
            return _X
        return _X, _y
|
| 98 |
+
|
| 99 |
+
|
| 100 |
+
class FeatureSelector(FeatureDictMixin, TransformerMixin, BaseEstimator):
    """
    Preprocessor that performs feature selection based on variance and correlation.

    This transformer selects features that:
    1. Have variance above a specified threshold.
    2. Are below a given pairwise correlation threshold.
    3. Among the remaining features, keeps only the top `max_features` with the highest variance.

    The input and output are both dictionaries mapping feature types to their corresponding
    feature matrices.

    Args:
        min_var (float): Minimum variance required for a feature to be retained.
            The variance filter is skipped when this is 0.0 or less.
        max_corr (float): Maximum allowed correlation between features.
            Features exceeding this threshold with others are removed.
            The correlation filter is skipped when this is 1.0 or more.
        max_features (int): Maximum number of features to keep after filtering.
            If -1, all remaining features are retained.
        feature_keys (list[str] | None): Keys of the input dict this selector
            operates on; forwarded to `FeatureDictMixin`.

    Input:
        dict[str, np.ndarray]: A dictionary where each key corresponds to a feature type
            and each value is a 2D NumPy array of shape (n_samples, n_features).

    Output:
        dict[str, np.ndarray]: A dictionary with the same keys as the input,
            containing only the selected features for each feature type.
    """

    def __init__(
        self, *, min_var=0.0, max_corr=1.0, max_features=-1, feature_keys=None
    ):
        self.min_var = min_var
        self.max_corr = max_corr
        self.max_features = max_features
        self._feature_mask = None

        super().__init__(feature_keys=feature_keys)

    def fit(self, X: dict[str, np.ndarray], y=None):
        """Determine which columns of the concatenated features are kept.

        Args:
            X: Feature dict; the arrays under `feature_keys` are concatenated
                column-wise before selection.
            y: Ignored; accepted for compatibility with ``fit(X, y)`` callers.

        Returns:
            The fitted instance (``self``).

        Raises:
            ValueError: If `max_features` is 0.
        """
        if self.max_features == 0:
            raise ValueError(
                f"max_features (={self.max_features}) must be -1 or larger 0."
            )

        _X = self.dict_to_array(X)

        # Start from "keep everything" so the later steps also work when the
        # variance filter is disabled.
        feature_mask = np.ones(_X.shape[1], dtype=bool)

        # keep features with a variance above min_var
        if self.min_var > 0.0:
            var_thresh = VarianceThreshold(threshold=self.min_var)
            feature_mask = np.array(var_thresh.fit(_X).get_support(), dtype=bool)

        # drop features that correlate too strongly with an earlier feature
        if self.max_corr < 1.0:
            selected = np.flatnonzero(feature_mask)
            # atleast_2d: corrcoef returns a 0-d value for a single column
            corr_matrix = np.atleast_2d(np.corrcoef(_X[:, selected], rowvar=False))
            upper_tri = np.triu(corr_matrix, k=1)
            # Column j is dropped if any earlier column i < j has a (signed)
            # correlation above max_corr with it. NaN correlations compare as
            # False and therefore never cause a drop.
            # NOTE(review): negative correlations are not considered - confirm
            # whether the absolute value was intended.
            to_keep = ~(upper_tri > self.max_corr).any(axis=0)

            feature_mask = np.zeros_like(feature_mask)
            feature_mask[selected[to_keep]] = True

        # keep only the max_features columns with the highest variance
        if self.max_features > 0:
            selected = np.flatnonzero(feature_mask)
            feature_vars = np.nanvar(_X[:, selected], axis=0)
            order = np.argsort(feature_vars)[: -(self.max_features + 1) : -1]

            feature_mask = np.zeros_like(feature_mask)
            feature_mask[selected[order]] = True

        self._feature_mask = feature_mask
        self.is_fitted_ = True
        return self

    def transform(self, X: dict[str, np.ndarray]) -> dict[str, np.ndarray]:
        """Return ``X`` reduced to the columns selected during `fit`.

        Raises:
            ValueError: If the selector has not been fitted yet.
        """
        if self._feature_mask is None:
            raise ValueError("FeatureSelector is not fitted yet. Call fit() first.")

        _X = self.dict_to_array(X)
        _X = _X[:, self._feature_mask]
        # Keep the column-to-key mapping aligned with the reduced array so
        # array_to_dict can split it back per feature type.
        self._curr_keys = self._curr_keys[self._feature_mask]
        return self.array_to_dict(_X)
|
| 180 |
+
|
| 181 |
+
|
| 182 |
+
class QuantileCreator(FeatureDictMixin, TransformerMixin, BaseEstimator):
    """
    Preprocessor that transforms features into empirical quantiles using ECDFs.

    This transformer applies an Empirical Cumulative Distribution Function (ECDF)
    to each feature and replaces feature values with their corresponding quantile
    ranks. One ECDF is fitted per column of the feature types listed in
    `feature_keys`; other feature types are passed through unchanged.

    Both input and output are dictionaries mapping feature types to their
    corresponding feature matrices.

    Args:
        feature_keys (list[str] | None): Keys of the input dict whose columns
            are converted to quantiles; forwarded to `FeatureDictMixin`.

    Input:
        dict[str, np.ndarray]: A dictionary where each key corresponds to a feature type
            and each value is a 2D NumPy array of shape (n_samples, n_features).

    Output:
        dict[str, np.ndarray]: A dictionary with the same keys as the input,
            where each feature value is replaced by its corresponding ECDF quantile rank.
    """

    def __init__(self, *, feature_keys=None):
        self._ecdfs = None
        super().__init__(feature_keys=feature_keys)

    def fit(self, X: dict[str, np.ndarray], y=None):
        """Fit one ECDF per column of the concatenated `feature_keys` arrays.

        Args:
            X: Feature dict used to estimate the ECDFs.
            y: Ignored; accepted for compatibility with ``fit(X, y)`` callers.

        Returns:
            The fitted instance (``self``).
        """
        _X = self.dict_to_array(X)
        self._ecdfs = [
            ECDF(_X[:, column].reshape(-1)) for column in range(_X.shape[1])
        ]
        self.is_fitted_ = True
        return self

    def transform(self, X: dict[str, np.ndarray]) -> dict[str, np.ndarray]:
        """Replace each value by the ECDF value of its column.

        Returns:
            A feature dict in which the `feature_keys` arrays hold quantiles.
        """
        _X = self.dict_to_array(X)

        # Quantiles are fractional. Keep a floating input dtype, but never
        # inherit an integer dtype, which would truncate every quantile.
        out_dtype = _X.dtype if np.issubdtype(_X.dtype, np.floating) else np.float64
        quantiles = np.zeros(_X.shape, dtype=out_dtype)
        for column in range(_X.shape[1]):
            raw_values = _X[:, column].reshape(-1)
            quantiles[:, column] = self._ecdfs[column](raw_values)

        return self.array_to_dict(quantiles)
|
| 227 |
+
|
| 228 |
+
|
| 229 |
+
class FeaturePreprocessor(TransformerMixin, BaseEstimator):
    """This class implements the feature preprocessing from a dictionary of molecule features.

    The steps, in order, are: optional quantile transformation, optional
    feature selection, column-wise concatenation of the `descriptors` arrays,
    scaling and (in `transform` only) sub-sampling.

    Args:
        feature_selection_config: Keyword arguments for `FeatureSelector` plus
            a required ``"use"`` entry that switches the step on or off.
        feature_quantilization_config: Keyword arguments for `QuantileCreator`
            plus a required ``"use"`` entry that switches the step on or off.
        descriptors: Feature types to concatenate, in this order.
        max_samples: Forwarded to `SubSampler`.
        scaler: Key into `SCALER_REGISTRY` selecting the scaler class.
    """

    def __init__(
        self,
        feature_selection_config: dict[str, Any],
        feature_quantilization_config: dict[str, Any],
        descriptors: list[str],
        max_samples: int = -1,
        scaler: str = "standard",
    ):
        self.descriptors = descriptors

        # Work on shallow copies: popping "use" from the caller's dicts would
        # make a second construction from the same config fail with KeyError.
        quantilization_kwargs = dict(feature_quantilization_config)
        self.use_feat_quant = quantilization_kwargs.pop("use")
        self.feature_quantilization_config = quantilization_kwargs
        self.quantile_creator = QuantileCreator(**quantilization_kwargs)

        selection_kwargs = dict(feature_selection_config)
        self.use_feat_selec = selection_kwargs.pop("use")
        self.feature_selection_config = selection_kwargs
        self.feature_selector = FeatureSelector(**selection_kwargs)

        self.max_samples = max_samples
        self.sub_sampler = SubSampler(max_samples=max_samples)

        self.scaler = SCALER_REGISTRY[scaler]()

    def __getstate__(self):
        """Return the state with the sub-transformers replaced by their own states."""
        state = super().__getstate__()
        state["quantile_creator"] = self.quantile_creator.__getstate__()
        state["feature_selector"] = self.feature_selector.__getstate__()
        state["sub_sampler"] = self.sub_sampler.__getstate__()
        state["scaler"] = self.scaler.__getstate__()
        return state

    def __setstate__(self, state):
        """Restore a state produced by `__getstate__`.

        The sub-transformers must already exist on the instance (i.e. it was
        built through `__init__`); their states are restored in place.
        """
        # Deep copy so popping the nested states leaves the caller's dict intact.
        _state = copy.deepcopy(state)
        self.quantile_creator.__setstate__(_state.pop("quantile_creator"))
        self.feature_selector.__setstate__(_state.pop("feature_selector"))
        self.sub_sampler.__setstate__(_state.pop("sub_sampler"))
        self.scaler.__setstate__(_state.pop("scaler"))
        super().__setstate__(_state)

    def fit(self, X: dict[str, np.ndarray], y=None):
        """Fit the processor transformers

        Args:
            X: Feature dict containing at least the configured descriptors.
            y: Ignored; accepted for compatibility with ``fit(X, y)`` callers.

        Returns:
            The fitted instance (``self``).
        """
        _X = copy.deepcopy(X)

        if self.use_feat_quant:
            _X = self.quantile_creator.fit_transform(_X)

        if self.use_feat_selec:
            _X = self.feature_selector.fit_transform(_X)

        _X = np.concatenate([_X[descr] for descr in self.descriptors], axis=1)
        self.scaler.fit(_X)
        return self

    def transform(
        self, X: dict[str, np.ndarray], y: np.ndarray | None = None
    ) -> np.ndarray | tuple[np.ndarray, np.ndarray]:
        """Apply the fitted preprocessing steps to a feature dict.

        Args:
            X: Feature dict containing at least the configured descriptors.
            y: Optional labels aligned with the rows of the feature arrays;
                sub-sampled together with the features.

        Returns:
            The processed feature array when ``y`` is None, otherwise the
            tuple ``(features, labels)``.
        """
        _X = X.copy()
        _y = y.copy() if y is not None else None

        if self.use_feat_quant:
            _X = self.quantile_creator.transform(_X)

        if self.use_feat_selec:
            _X = self.feature_selector.transform(_X)

        _X = np.concatenate([_X[descr] for descr in self.descriptors], axis=1)
        _X = self.scaler.transform(_X)

        if _y is None:
            _X = self.sub_sampler.transform(_X)
            return _X

        _X, _y = self.sub_sampler.transform(_X, _y)
        return _X, _y
|
| 307 |
|
| 308 |
|
| 309 |
def create_cleaned_mol_objects(smiles: list[str]) -> tuple[list[Mol], np.ndarray]:
|
|
|
|
| 482 |
|
| 483 |
def create_descriptors(
    smiles,
    descriptors,
    **ecfp_kwargs,
):
    """Generate molecular descriptors for multiple SMILES strings.

    Each SMILES is processed and sanitized using RDKit.
    SMILES that cannot be sanitized are encoded with NaNs, and a corresponding boolean mask
    is returned to indicate which inputs were successfully processed.

    Args:
        smiles (list[str]): List of SMILES strings for which to generate descriptors.
        descriptors (list[str]): List of descriptor types to compute.
            Supported values include:
            ['ecfps', 'tox', 'maccs', 'rdkit_descrs'].
        **ecfp_kwargs: Forwarded to `create_ecfp_fps` when 'ecfps' is requested.

    Returns:
        tuple[dict[str, np.ndarray], np.ndarray]:
            - A dictionary mapping descriptor names to their computed arrays.
            - A boolean mask of shape (len(smiles),) indicating which SMILES
              were successfully sanitized and processed.

    Raises:
        KeyError: If `descriptors` contains an unsupported descriptor type.
    """
    # Validate up front: a mistyped name must neither trigger the expensive
    # feature computation nor silently resolve to some unrelated local variable.
    supported = ("ecfps", "tox", "maccs", "rdkit_descrs")
    unknown = [descr for descr in descriptors if descr not in supported]
    if unknown:
        raise KeyError(
            f"Unknown descriptor type(s) {unknown}; supported: {list(supported)}"
        )

    # Create cleaned rdkit mol objects
    mols, clean_mol_mask = create_cleaned_mol_objects(smiles)
    print(f"Cleaned molecules, {(~clean_mol_mask).sum()} could not be sanitized")

    # Create fingerprints and descriptors
    computed = {}

    if "ecfps" in descriptors:
        ecfps = create_ecfp_fps(mols, **ecfp_kwargs)
        computed["ecfps"] = fill(ecfps, ~clean_mol_mask)
        print("Created ECFP fingerprints")

    if "tox" in descriptors:
        tox_patterns = get_tox_patterns(TOX_SMARTS_PATH)
        tox = create_tox_features(mols, tox_patterns)
        computed["tox"] = fill(tox, ~clean_mol_mask)
        print("Created Tox features")

    if "maccs" in descriptors:
        maccs = create_maccs_keys(mols)
        computed["maccs"] = fill(maccs, ~clean_mol_mask)
        print("Created MACCS keys")

    if "rdkit_descrs" in descriptors:
        rdkit_descrs = create_rdkit_descriptors(mols)
        computed["rdkit_descrs"] = fill(rdkit_descrs, ~clean_mol_mask)
        print("Created RDKit descriptors")

    # collect features in the order requested by the caller
    features = {descr: computed[descr] for descr in descriptors}

    return features, clean_mol_mask
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 538 |
|
| 539 |
|
| 540 |
def get_tox21_split(token, cvfold=None):
|
src/utils.py
CHANGED
|
@@ -7,6 +7,9 @@
|
|
| 7 |
|
| 8 |
import os
|
| 9 |
import pickle
|
|
|
|
|
|
|
|
|
|
| 10 |
|
| 11 |
from rdkit import Chem
|
| 12 |
from rdkit.Chem.MolStandardize import rdMolStandardize
|
|
@@ -29,7 +32,7 @@ TASKS = [
|
|
| 29 |
"SR-p53",
|
| 30 |
]
|
| 31 |
|
| 32 |
-
KNOWN_DESCR = ["ecfps", "
|
| 33 |
|
| 34 |
USED_200_DESCR = [
|
| 35 |
0,
|
|
@@ -433,6 +436,63 @@ class Standardizer:
|
|
| 433 |
return mol_out, n_tautomers
|
| 434 |
|
| 435 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 436 |
def load_pickle(path: str):
|
| 437 |
with open(path, "rb") as file:
|
| 438 |
content = pickle.load(file)
|
|
@@ -459,7 +519,7 @@ def normalize_config(config: dict):
|
|
| 459 |
for key, val in config.items():
|
| 460 |
if isinstance(val, dict):
|
| 461 |
new_config[key] = normalize_config(val)
|
| 462 |
-
elif val in mapping:
|
| 463 |
new_config[key] = mapping[val]
|
| 464 |
else:
|
| 465 |
new_config[key] = val
|
|
|
|
| 7 |
|
| 8 |
import os
|
| 9 |
import pickle
|
| 10 |
+
from typing import Any
|
| 11 |
+
|
| 12 |
+
import numpy as np
|
| 13 |
|
| 14 |
from rdkit import Chem
|
| 15 |
from rdkit.Chem.MolStandardize import rdMolStandardize
|
|
|
|
| 32 |
"SR-p53",
|
| 33 |
]
|
| 34 |
|
| 35 |
+
KNOWN_DESCR = ["ecfps", "tox", "maccs", "rdkit_descrs"]
|
| 36 |
|
| 37 |
USED_200_DESCR = [
|
| 38 |
0,
|
|
|
|
| 436 |
return mol_out, n_tautomers
|
| 437 |
|
| 438 |
|
| 439 |
+
class FeatureDictMixin:
    """
    Mixin that enables bidirectional handling of dict-based multi-feature inputs.
    Allows selective removal of columns directly from the combined array.

    `dict_to_array` concatenates the arrays stored under `feature_keys` and
    remembers which key every column came from; `array_to_dict` uses that
    mapping to split a (possibly column-reduced) array back into a dict and
    re-attaches the entries that were not part of `feature_keys`.

    Example input:
        {
            "ecfps": np.ndarray,
            "tox": np.ndarray,
        }
    """

    def __init__(self, feature_keys=None):
        self.feature_keys = feature_keys
        # Per-column source key of the most recent dict_to_array() result.
        self._curr_keys = None
        # Entries of the most recent input that are not listed in feature_keys.
        self._unused_data = None

    def dict_to_array(self, input: dict[Any, np.ndarray]) -> np.ndarray:
        """Parse dict input and concatenate into a single array.

        Args:
            input: Mapping from feature type to a 2D array.

        Returns:
            The arrays of `feature_keys`, concatenated along axis 1 in the
            order of `feature_keys`.

        Raises:
            TypeError: If `input` is not a dict.
            KeyError: If a key of `feature_keys` is missing from `input`.
            ValueError: If one of the selected arrays is not 2D.
        """
        if not isinstance(input, dict):
            raise TypeError("Input must be a dict {feature_type: np.ndarray, ...}")

        unused_data = {}
        remaining_input = {}
        for key, value in input.items():
            if key not in self.feature_keys:
                unused_data[key] = value
            else:
                remaining_input[key] = value

        curr_keys = []
        output = []
        for key in self.feature_keys:
            array = remaining_input.pop(key)
            if array.ndim != 2:
                raise ValueError(f"Feature '{key}' must be 2D, got shape {array.shape}")

            curr_keys.extend([key] * array.shape[1])
            output.append(array)

        # Update the instance state only after validation succeeded, so a
        # failed call does not leave a half-initialized mapping behind.
        self._unused_data = unused_data
        self._curr_keys = np.array(curr_keys)

        return np.concatenate(output, axis=1)

    def array_to_dict(self, input: np.ndarray) -> dict[Any, np.ndarray]:
        """Reconstruct dict from a concatenated array.

        Args:
            input: 2D array whose columns line up with the stored column-to-key
                mapping.

        Returns:
            A dict holding one array per key of `feature_keys` plus the unused
            entries recorded by the preceding `dict_to_array` call.

        Raises:
            ValueError: If no mapping is stored, i.e. `dict_to_array` was not
                called since the last `array_to_dict`.
        """
        if self._curr_keys is None:
            raise ValueError(
                "No feature mapping stored. Did you call dict_to_array()?"
            )

        output = {key: input[:, self._curr_keys == key] for key in self.feature_keys}
        output.update(self._unused_data)

        # The mapping is single-use: reset it so stale state is never reused.
        self._curr_keys = None
        self._unused_data = None
        return output
|
| 494 |
+
|
| 495 |
+
|
| 496 |
def load_pickle(path: str):
|
| 497 |
with open(path, "rb") as file:
|
| 498 |
content = pickle.load(file)
|
|
|
|
| 519 |
for key, val in config.items():
|
| 520 |
if isinstance(val, dict):
|
| 521 |
new_config[key] = normalize_config(val)
|
| 522 |
+
elif isinstance(val, (int, float, str)) and val in mapping:
|
| 523 |
new_config[key] = mapping[val]
|
| 524 |
else:
|
| 525 |
new_config[key] = val
|
train.py
CHANGED
|
@@ -4,6 +4,7 @@ Script for fitting and saving any preprocessing assets, as well as the fitted RF
|
|
| 4 |
|
| 5 |
import os
|
| 6 |
import json
|
|
|
|
| 7 |
import random
|
| 8 |
import logging
|
| 9 |
import argparse
|
|
@@ -12,11 +13,8 @@ import numpy as np
|
|
| 12 |
from datetime import datetime
|
| 13 |
|
| 14 |
from src.model import Tox21RFClassifier
|
| 15 |
-
from src.
|
| 16 |
-
|
| 17 |
-
normalize_config,
|
| 18 |
-
USED_200_DESCR,
|
| 19 |
-
)
|
| 20 |
|
| 21 |
parser = argparse.ArgumentParser(description="RF Training script for Tox21 dataset")
|
| 22 |
|
|
@@ -27,7 +25,7 @@ parser.add_argument(
|
|
| 27 |
)
|
| 28 |
|
| 29 |
|
| 30 |
-
def main(
|
| 31 |
timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
|
| 32 |
|
| 33 |
# setup logger
|
|
@@ -39,7 +37,7 @@ def main(cfg):
|
|
| 39 |
handlers=[
|
| 40 |
logging.FileHandler(
|
| 41 |
os.path.join(
|
| 42 |
-
|
| 43 |
f"{script_name}_{timestamp}.log",
|
| 44 |
)
|
| 45 |
),
|
|
@@ -47,50 +45,50 @@ def main(cfg):
|
|
| 47 |
],
|
| 48 |
)
|
| 49 |
|
| 50 |
-
|
| 51 |
-
|
| 52 |
-
|
| 53 |
-
[str(val) for key, val in task_configs.items()]
|
| 54 |
)
|
| 55 |
-
|
| 56 |
-
logger.info(f"Task configs: \n{task_configs_repr}")
|
| 57 |
|
| 58 |
# seeding
|
| 59 |
-
random.seed(
|
| 60 |
-
np.random.seed(
|
| 61 |
-
|
| 62 |
-
train_data = np.load(os.path.join(
|
| 63 |
-
|
| 64 |
-
|
| 65 |
-
|
| 66 |
-
|
| 67 |
-
|
| 68 |
-
|
| 69 |
-
|
| 70 |
-
|
| 71 |
-
|
| 72 |
-
|
| 73 |
-
|
| 74 |
-
|
| 75 |
-
|
| 76 |
-
|
| 77 |
-
|
| 78 |
-
if
|
| 79 |
logger.info(
|
| 80 |
-
f"Fitted RandomForestClassifier will be saved as: {
|
| 81 |
)
|
| 82 |
else:
|
| 83 |
logger.info("Fitted RandomForestClassifier will NOT be saved.")
|
| 84 |
|
| 85 |
-
|
| 86 |
-
|
| 87 |
-
|
| 88 |
-
|
| 89 |
-
|
| 90 |
-
|
| 91 |
-
|
| 92 |
-
|
|
|
|
| 93 |
)
|
|
|
|
| 94 |
|
| 95 |
logger.info("Start training.")
|
| 96 |
for i, task in enumerate(model.tasks):
|
|
@@ -98,28 +96,34 @@ def main(cfg):
|
|
| 98 |
label_mask = ~np.isnan(task_labels)
|
| 99 |
logger.info(f"Fit task {task} using {sum(label_mask)} samples")
|
| 100 |
|
| 101 |
-
task_data =
|
| 102 |
task_labels = task_labels[label_mask].astype(int)
|
| 103 |
|
|
|
|
| 104 |
model.fit(task, task_data, task_labels)
|
| 105 |
-
if
|
| 106 |
break
|
| 107 |
|
| 108 |
log_text = f"Finished training."
|
| 109 |
logger.info(log_text)
|
| 110 |
|
| 111 |
-
if
|
| 112 |
-
|
| 113 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 114 |
|
| 115 |
|
| 116 |
if __name__ == "__main__":
|
| 117 |
args = parser.parse_args()
|
| 118 |
|
| 119 |
with open(args.config, "r") as f:
|
| 120 |
-
|
| 121 |
-
|
| 122 |
|
| 123 |
-
create_dir(
|
| 124 |
|
| 125 |
-
main(
|
|
|
|
| 4 |
|
| 5 |
import os
|
| 6 |
import json
|
| 7 |
+
import joblib
|
| 8 |
import random
|
| 9 |
import logging
|
| 10 |
import argparse
|
|
|
|
| 13 |
from datetime import datetime
|
| 14 |
|
| 15 |
from src.model import Tox21RFClassifier
|
| 16 |
+
from src.preprocess import Tox21Preprocessor
|
| 17 |
+
from src.utils import create_dir, normalize_config,
|
|
|
|
|
|
|
|
|
|
| 18 |
|
| 19 |
parser = argparse.ArgumentParser(description="RF Training script for Tox21 dataset")
|
| 20 |
|
|
|
|
| 25 |
)
|
| 26 |
|
| 27 |
|
| 28 |
+
def main(config):
|
| 29 |
timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
|
| 30 |
|
| 31 |
# setup logger
|
|
|
|
| 37 |
handlers=[
|
| 38 |
logging.FileHandler(
|
| 39 |
os.path.join(
|
| 40 |
+
config["log_folder"],
|
| 41 |
f"{script_name}_{timestamp}.log",
|
| 42 |
)
|
| 43 |
),
|
|
|
|
| 45 |
],
|
| 46 |
)
|
| 47 |
|
| 48 |
+
logger.info(f"Config: {config}")
|
| 49 |
+
model_configs_repr = "Model configs: \n" + "\n".join(
|
| 50 |
+
[str(val) for val in config["model_configs"].values()]
|
|
|
|
| 51 |
)
|
| 52 |
+
logger.info(f"Model configs: \n{model_configs_repr}")
|
|
|
|
| 53 |
|
| 54 |
# seeding
|
| 55 |
+
random.seed(config["seed"])
|
| 56 |
+
np.random.seed(config["seed"])
|
| 57 |
+
|
| 58 |
+
train_data = np.load(os.path.join(config["data_folder"], "tox21_train_cv4.npz"))
|
| 59 |
+
val_data = np.load(os.path.join(config["data_folder"], "tox21_validation_cv4.npz"))
|
| 60 |
+
|
| 61 |
+
# filter out unsanitized molecules
|
| 62 |
+
train_is_clean = train_data["clean_mol_mask"]
|
| 63 |
+
val_is_clean = val_data["clean_mol_mask"]
|
| 64 |
+
train_data = {descr: array[train_is_clean] for descr, array in train_data.items()}
|
| 65 |
+
val_data = {descr: array[val_is_clean] for descr, array in val_data.items()}
|
| 66 |
+
|
| 67 |
+
# combine datasets
|
| 68 |
+
data = {
|
| 69 |
+
descr: np.concatenate([train_data[descr], val_data[descr]], axis=0)
|
| 70 |
+
for descr in config["descriptors"]
|
| 71 |
+
}
|
| 72 |
+
labels = np.concatenate([train_data["labels"], val_data["labels"]], axis=0)
|
| 73 |
+
|
| 74 |
+
if config["ckpt_path"]:
|
| 75 |
logger.info(
|
| 76 |
+
f"Fitted RandomForestClassifier will be saved as: {config['ckpt_path']}"
|
| 77 |
)
|
| 78 |
else:
|
| 79 |
logger.info("Fitted RandomForestClassifier will NOT be saved.")
|
| 80 |
|
| 81 |
+
model = Tox21RFClassifier(seed=config["seed"], config=config["model_configs"])
|
| 82 |
+
|
| 83 |
+
# setup processors
|
| 84 |
+
preprocessor = Tox21Preprocessor(
|
| 85 |
+
feature_selection_config=config["feature_selection"],
|
| 86 |
+
feature_quantilization_config=config["feature_quantilization"],
|
| 87 |
+
descriptors=config["descriptors"],
|
| 88 |
+
max_samples=config["max_samples"],
|
| 89 |
+
scaler=config["scaler"],
|
| 90 |
)
|
| 91 |
+
preprocessor.fit(data)
|
| 92 |
|
| 93 |
logger.info("Start training.")
|
| 94 |
for i, task in enumerate(model.tasks):
|
|
|
|
| 96 |
label_mask = ~np.isnan(task_labels)
|
| 97 |
logger.info(f"Fit task {task} using {sum(label_mask)} samples")
|
| 98 |
|
| 99 |
+
task_data = {key: val[label_mask] for key, val in data.items()}
|
| 100 |
task_labels = task_labels[label_mask].astype(int)
|
| 101 |
|
| 102 |
+
task_data = preprocessor.transform(task_data)
|
| 103 |
model.fit(task, task_data, task_labels)
|
| 104 |
+
if config["debug"]:
|
| 105 |
break
|
| 106 |
|
| 107 |
log_text = f"Finished training."
|
| 108 |
logger.info(log_text)
|
| 109 |
|
| 110 |
+
if config["ckpt_path"]:
|
| 111 |
+
ckpt = {
|
| 112 |
+
"preprocessor": preprocessor.__getstate__(),
|
| 113 |
+
"models": model.get_state(),
|
| 114 |
+
}
|
| 115 |
+
# model.save_model(config["ckpt_path"])
|
| 116 |
+
joblib.dump(ckpt, config["ckpt_path"])
|
| 117 |
+
logger.info(f"Save model as: {config['ckpt_path']}")
|
| 118 |
|
| 119 |
|
| 120 |
if __name__ == "__main__":
    # Script entry point: read the JSON config given on the command line,
    # normalize it, make sure the log folder exists and start the training.
    args = parser.parse_args()

    with open(args.config, "r") as config_file:
        raw_config = json.load(config_file)
    config = normalize_config(raw_config)

    create_dir(config["log_folder"])

    main(config)
|