"""
This file includes a RF model for Tox21.
As an input it takes a list of SMILES and it outputs a nested dictionary with
SMILES and target names as keys.

Note: the classifier defined in this file operates on precomputed molecule
feature arrays, one random forest per Tox21 task.
"""

# ---------------------------------------------------------------------------------------
# Dependencies

from typing import Optional

import joblib
import numpy as np
from sklearn.ensemble import RandomForestClassifier

from .utils import TASKS

# ---------------------------------------------------------------------------------------


class Tox21RFClassifier:
    """A collection of random forest classifiers, one per Tox21 task.

    Each task listed in ``TASKS`` gets its own ``RandomForestClassifier``.
    Models are trained and queried per task on molecule feature arrays and
    return the predicted probability of the positive class.
    """

    def __init__(self, seed: int = 42, config: Optional[dict] = None):
        """Initialize one random forest classifier per Tox21 task.

        Args:
            seed (int, optional): seed for RF to ensure reproducibility.
                Defaults to 42.
            config (dict, optional): per-task hyperparameters. When given, it
                must contain an entry for every task; ``config[task]`` is
                unpacked as keyword arguments into ``RandomForestClassifier``.
                Because ``random_state`` and ``n_jobs`` are already passed
                explicitly, they must not appear in ``config[task]``.
                Defaults to None, in which case every forest uses
                ``n_estimators=1000``.

        Raises:
            KeyError: if ``config`` is given but lacks an entry for a task.
        """
        self.tasks = TASKS
        self.models = {
            task: RandomForestClassifier(
                random_state=seed,
                n_jobs=8,
                **({"n_estimators": 1000} if config is None else config[task]),
            )
            for task in self.tasks
        }

    def _check_task(self, task: str) -> None:
        """Raise ``AssertionError`` if ``task`` is not a known task."""
        # Raised explicitly instead of using `assert`, so the check is not
        # stripped when Python runs with -O; the exception type is unchanged.
        if task not in self.tasks:
            raise AssertionError(f"Unknown task: {task}")

    def load(self, path: str) -> None:
        """Load model from filepath

        Replaces ``self.models`` with the object stored at ``path``.

        Args:
            path (str): filepath to model checkpoint
        """
        # NOTE(security): joblib.load unpickles the file, which can execute
        # arbitrary code. Only load checkpoints from trusted sources.
        self.models = joblib.load(path)

    def save(self, path: str) -> None:
        """Save model to filepath

        Args:
            path (str): filepath to model checkpoint
        """
        joblib.dump(self.models, path)

    def fit(self, task: str, X: np.ndarray, y: np.ndarray) -> None:
        """Train the random forest for a given task

        Args:
            task (str): task to train
            X (np.ndarray): training features
            y (np.ndarray): training labels

        Raises:
            AssertionError: if ``task`` is not one of the known tasks.
        """
        self._check_task(task)

        # Fit on copies so the estimator never holds the caller's arrays.
        _X, _y = X.copy(), y.copy()
        self.models[task].fit(_X, _y)

    def predict(self, task: str, X: np.ndarray) -> np.ndarray:
        """Predicts labels for a given Tox21 target using molecule features

        Args:
            task (str): the Tox21 target to predict for
            X (np.ndarray): molecule features used for prediction

        Returns:
            np.ndarray: predicted probability for positive class

        Raises:
            AssertionError: if ``task`` is unknown or ``X`` is not 2D.
        """
        self._check_task(task)
        if X.ndim != 2:
            raise AssertionError(
                f"Function expects 2D np.array. Current shape: {X.shape}"
            )

        _X = X.copy()
        # Column 1 of predict_proba is the second class, so the fitted model
        # must have seen two classes during training.
        return self.models[task].predict_proba(_X)[:, 1]