antoniaebner's picture
adapt load/saving, preprocessing, app, readme, modelcard
97697e0
"""
This file includes a random forest (RF) model for Tox21.
As input, it takes a list of SMILES and outputs a nested dictionary with
SMILES and target names as keys.
"""
# ---------------------------------------------------------------------------------------
# Dependencies
import joblib
import numpy as np
from sklearn.ensemble import RandomForestClassifier
from .utils import TASKS
# ---------------------------------------------------------------------------------------
class Tox21RFClassifier:
    """Per-task random forest classifiers that score molecules for Tox21 toxicity.

    One scikit-learn ``RandomForestClassifier`` is kept per task in
    ``self.models`` (keyed by task name). Training and prediction operate on
    precomputed molecule feature arrays.
    """

    def __init__(self, seed: int = 42, config: dict = None, n_jobs: int = 8):
        """Initialize a random forest classifier for each of the 12 Tox21 tasks.

        Args:
            seed (int, optional): seed for RF to ensure reproducibility. Defaults to 42.
            config (dict, optional): per-task ``RandomForestClassifier`` keyword
                arguments, indexed by task name. It must contain an entry for
                every task in ``TASKS``. An entry may override ``random_state``
                and ``n_jobs``. Defaults to None, which uses
                ``n_estimators=1000`` for every task.
            n_jobs (int, optional): number of parallel jobs used by each forest.
                Defaults to 8.

        Raises:
            KeyError: if ``config`` is given but has no entry for some task.
        """
        self.tasks = TASKS
        self.models = {}
        for task in self.tasks:
            # Put the base settings in first, then apply the task-specific ones.
            # Merging the dicts lets a config set random_state/n_jobs without
            # raising "got multiple values for keyword argument".
            params = {"random_state": seed, "n_jobs": n_jobs}
            params.update({"n_estimators": 1000} if config is None else config[task])
            self.models[task] = RandomForestClassifier(**params)

    def load(self, path: str) -> None:
        """Load the per-task models from a checkpoint file.

        Replaces ``self.models`` with the object stored at ``path``.

        Args:
            path (str): filepath to model checkpoint (as written by ``save``)
        """
        # SECURITY: joblib.load unpickles the file, and unpickling can execute
        # arbitrary code. Only load checkpoints from trusted sources.
        self.models = joblib.load(path)

    def save(self, path: str) -> None:
        """Save the per-task models to a checkpoint file.

        Args:
            path (str): filepath to model checkpoint
        """
        joblib.dump(self.models, path)

    def fit(self, task: str, X: np.ndarray, y: np.ndarray) -> None:
        """Train the random forest for a given task.

        Args:
            task (str): task to train; must be one of ``self.tasks``
            X (np.ndarray): training features
            y (np.ndarray): training labels

        Raises:
            AssertionError: if ``task`` is unknown.
        """
        assert task in self.tasks, f"Unknown task: {task}"
        # scikit-learn estimators do not modify their inputs, so a defensive
        # copy would only cost an extra allocation of the feature matrix.
        self.models[task].fit(X, y)

    def predict(self, task: str, X: np.ndarray) -> np.ndarray:
        """Predict the positive-class probability for a Tox21 target.

        Args:
            task (str): the Tox21 target to predict for; must be one of ``self.tasks``
            X (np.ndarray): 2D array of molecule features used for prediction

        Returns:
            np.ndarray: predicted probability for the positive class, one value per row of X

        Raises:
            AssertionError: if ``task`` is unknown or ``X`` is not 2D.
        """
        assert task in self.tasks, f"Unknown task: {task}"
        assert (
            len(X.shape) == 2
        ), f"Function expects 2D np.array. Current shape: {X.shape}"
        model = self.models[task]
        proba = model.predict_proba(X)
        if proba.shape[1] == 1:
            # The model saw only one class during fit, so predict_proba has a
            # single column and indexing [:, 1] would raise IndexError.
            # Return certainty for that class: all ones if it is the positive
            # label, all zeros otherwise.
            only_class = model.classes_[0]
            return proba[:, 0] if only_class == 1 else np.zeros(proba.shape[0])
        return proba[:, 1]