Spaces:
Sleeping
Sleeping
File size: 2,565 Bytes
81226cb 97697e0 81226cb d84754c 82027a5 81226cb a8d912f 81226cb 1994acc a8d912f 81226cb 3fd3838 1994acc 82027a5 3fd3838 a8d912f 97697e0 a8d912f 97697e0 81226cb 97697e0 3fd3838 97697e0 33fd417 3fd3838 a8d912f 3fd3838 a8d912f 81226cb 1994acc 3fd3838 a8d912f 3fd3838 a8d912f 81226cb 117adda 3fd3838 1994acc |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 |
"""
This file includes a random forest (RF) model for Tox21.
As an input it takes a list of SMILES and it outputs a nested dictionary with
SMILES and target names as keys.
"""
# ---------------------------------------------------------------------------------------
# Dependencies
from typing import Optional

import joblib
import numpy as np
from sklearn.ensemble import RandomForestClassifier

from .utils import TASKS
# ---------------------------------------------------------------------------------------
class Tox21RFClassifier:
    """Collection of random forest classifiers, one per Tox21 task.

    Every entry of ``TASKS`` gets its own ``RandomForestClassifier``. The
    models are trained and queried independently through :meth:`fit` and
    :meth:`predict`, both of which operate on molecule feature matrices.
    """

    def __init__(
        self,
        seed: int = 42,
        config: Optional[dict] = None,
        *,
        n_jobs: Optional[int] = 8,
    ) -> None:
        """Create one untrained random forest per task in ``TASKS``.

        Args:
            seed: ``random_state`` passed to every forest for reproducibility.
                Defaults to 42.
            config: optional mapping from task name to a dict of extra
                ``RandomForestClassifier`` keyword arguments. It must contain
                an entry for every task. When ``None``, each forest is built
                with ``n_estimators=1000``.
            n_jobs: number of parallel jobs handed to every forest. Defaults
                to 8, the value that used to be hard-coded.
        """
        self.tasks = TASKS
        self.models = {
            task: RandomForestClassifier(
                random_state=seed,
                n_jobs=n_jobs,
                **({"n_estimators": 1000} if config is None else config[task]),
            )
            for task in self.tasks
        }

    def _check_task(self, task: str) -> None:
        """Raise ``AssertionError`` if ``task`` is not a known task.

        The error is raised explicitly instead of through an ``assert``
        statement so the check is not stripped under ``python -O``. The
        exception type and message match the previous ``assert``.
        """
        if task not in self.tasks:
            raise AssertionError(f"Unknown task: {task}")

    def load(self, path: str) -> None:
        """Replace the per-task models with the ones stored at ``path``.

        Args:
            path: filepath to a model checkpoint written by :meth:`save`.
        """
        # SECURITY: joblib.load unpickles the file and can execute arbitrary
        # code; only load checkpoints that come from a trusted source.
        self.models = joblib.load(path)

    def save(self, path: str) -> None:
        """Serialize the per-task models to ``path``.

        Args:
            path: filepath the model checkpoint is written to.
        """
        joblib.dump(self.models, path)

    def fit(self, task: str, X: np.ndarray, y: np.ndarray) -> None:
        """Train the random forest that belongs to ``task``.

        Args:
            task: task to train; must be one of ``self.tasks``.
            X: training features.
            y: training labels.

        Raises:
            AssertionError: if ``task`` is not a known task.
        """
        self._check_task(task)
        # Train on copies so the estimator never holds the caller's arrays.
        features, labels = X.copy(), y.copy()
        self.models[task].fit(features, labels)

    def predict(self, task: str, X: np.ndarray) -> np.ndarray:
        """Predict positive-class probabilities for one task.

        Args:
            task: task to predict for; must be one of ``self.tasks``.
            X: 2D array of molecule features used for prediction.

        Returns:
            Column 1 of the model's ``predict_proba`` output, i.e. the
            predicted probability of the positive class.

        Raises:
            AssertionError: if ``task`` is unknown or ``X`` is not 2D.
        """
        self._check_task(task)
        if len(X.shape) != 2:
            # Raised explicitly (not via ``assert``) so it survives ``-O``.
            raise AssertionError(
                f"Function expects 2D np.array. Current shape: {X.shape}"
            )
        features = X.copy()
        return self.models[task].predict_proba(features)[:, 1]
|