Spaces:
Sleeping
Sleeping
| """ | |
| This file includes an RF model for Tox21. | |
| As an input it takes a list of SMILES and it outputs a nested dictionary with | |
| SMILES and target names as keys. | |
| """ | |
| # --------------------------------------------------------------------------------------- | |
| # Dependencies | |
| import joblib | |
| import numpy as np | |
| from sklearn.ensemble import RandomForestClassifier | |
| from .utils import TASKS | |
| # --------------------------------------------------------------------------------------- | |
class Tox21RFClassifier:
    """Per-task random forest classifiers for the Tox21 targets.

    One ``RandomForestClassifier`` is kept for every entry of ``TASKS``; each
    method that takes a ``task`` argument addresses the forest stored under
    that name in ``self.models``.
    """

    def __init__(
        self,
        seed: int = 42,
        config: "dict | None" = None,
        n_jobs: int = 8,
    ):
        """Create one untrained random forest for each task in ``TASKS``.

        Args:
            seed (int, optional): ``random_state`` handed to every forest to
                ensure reproducibility. Defaults to 42.
            config (dict, optional): mapping from task name to a dict of keyword
                arguments for that task's ``RandomForestClassifier``. It needs
                an entry for every task and must not contain ``random_state``
                or ``n_jobs``, because both are passed explicitly here. When
                None, every forest is built with ``n_estimators=1000``.
            n_jobs (int, optional): ``n_jobs`` handed to every forest.
                Defaults to 8, the value that used to be hard-coded.
        """
        self.tasks = TASKS
        default_params = {"n_estimators": 1000}
        self.models = {
            task: RandomForestClassifier(
                random_state=seed,
                n_jobs=n_jobs,
                **(default_params if config is None else config[task]),
            )
            for task in self.tasks
        }

    def load(self, path: str) -> None:
        """Replace ``self.models`` with the object stored at ``path``.

        Args:
            path (str): filepath to model checkpoint
        """
        # NOTE(review): joblib.load unpickles the file, which can execute
        # arbitrary code -- only load checkpoints from a trusted source.
        self.models = joblib.load(path)

    def save(self, path: str) -> None:
        """Write ``self.models`` (the task -> forest mapping) to ``path``.

        Args:
            path (str): filepath to model checkpoint
        """
        joblib.dump(self.models, path)

    def fit(self, task: str, X: np.ndarray, y: np.ndarray) -> None:
        """Train the random forest of a single task.

        Args:
            task (str): task to train; must be one of ``self.tasks``
            X (np.ndarray): training features
            y (np.ndarray): training labels
        """
        assert task in self.tasks, f"Unknown task: {task}"
        # The estimator is handed private copies, never the caller's arrays.
        features, labels = X.copy(), y.copy()
        self.models[task].fit(features, labels)

    def predict(self, task: str, X: np.ndarray) -> np.ndarray:
        """Predict the positive-class probability for a given Tox21 target.

        Args:
            task (str): the Tox21 target to predict for; must be one of
                ``self.tasks``
            X (np.ndarray): 2D array of molecule features used for prediction

        Returns:
            np.ndarray: predicted probability for positive class, one value
                per row of ``X``
        """
        assert task in self.tasks, f"Unknown task: {task}"
        assert (
            len(X.shape) == 2
        ), f"Function expects 2D np.array. Current shape: {X.shape}"
        model = self.models[task]
        # The estimator is handed a private copy, never the caller's array.
        proba = model.predict_proba(X.copy())
        if proba.shape[1] > 1:
            return proba[:, 1]
        # Only one class was seen during training, so there is a single
        # probability column and indexing column 1 would raise IndexError.
        # NOTE(review): assumes the positive class is labelled 1 (0/1 labels)
        # -- confirm against the training data.
        if model.classes_[0] == 1:
            return proba[:, 0]
        return np.zeros(proba.shape[0], dtype=proba.dtype)