File size: 2,565 Bytes
81226cb
 
 
 
 
 
 
 
97697e0
81226cb
 
 
d84754c
82027a5
81226cb
 
 
a8d912f
81226cb
1994acc
a8d912f
 
 
 
 
81226cb
3fd3838
 
 
 
 
1994acc
82027a5
3fd3838
 
a8d912f
97697e0
 
a8d912f
 
97697e0
81226cb
97697e0
3fd3838
97697e0
 
 
 
 
 
 
33fd417
3fd3838
a8d912f
 
 
 
3fd3838
 
a8d912f
81226cb
1994acc
 
3fd3838
 
a8d912f
 
 
 
3fd3838
a8d912f
 
 
81226cb
 
117adda
3fd3838
 
1994acc
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
"""
This file contains a random forest (RF) model for Tox21.
As input, it takes a list of SMILES strings and outputs a nested dictionary
keyed by SMILES string and target name.
"""

# ---------------------------------------------------------------------------------------
# Dependencies
from typing import Optional

import joblib
import numpy as np
from sklearn.ensemble import RandomForestClassifier

from .utils import TASKS


# ---------------------------------------------------------------------------------------
class Tox21RFClassifier:
    """Random forest classifiers that assign Tox21 toxicity scores to molecules.

    One independent ``RandomForestClassifier`` is kept per task listed in
    ``TASKS``, stored in ``self.models`` keyed by task name. The models work on
    precomputed molecule feature matrices, not on raw SMILES strings.
    """

    def __init__(
        self, seed: int = 42, config: Optional[dict] = None, n_jobs: int = 8
    ):
        """Initialize one random forest classifier per Tox21 task.

        Args:
            seed (int, optional): seed for RF to ensure reproducibility. Defaults to 42.
            config (dict, optional): per-task hyperparameters, mapping every task
                name in ``TASKS`` to a dict of ``RandomForestClassifier`` keyword
                arguments. Entries in a task's dict take precedence over ``seed``
                and ``n_jobs``. If None, every forest uses ``n_estimators=1000``.
                Defaults to None.
            n_jobs (int, optional): number of parallel jobs used by each forest.
                Defaults to 8.

        Raises:
            KeyError: if ``config`` is given but lacks an entry for some task.
        """
        self.tasks = TASKS

        self.models = {}
        for task in self.tasks:
            task_params = {"n_estimators": 1000} if config is None else config[task]
            # Merge instead of passing random_state/n_jobs next to **task_params:
            # a config entry that also sets one of them would otherwise raise
            # "got multiple values for keyword argument". Task config wins.
            params = {"random_state": seed, "n_jobs": n_jobs, **task_params}
            self.models[task] = RandomForestClassifier(**params)

    def _check_task(self, task: str) -> None:
        """Raise ValueError if ``task`` is not one of ``self.tasks``."""
        # An explicit raise (not assert) so the check survives `python -O`.
        if task not in self.tasks:
            raise ValueError(f"Unknown task: {task}")

    def load(self, path: str) -> None:
        """Load model from filepath

        Replaces ``self.models`` with the object stored in the checkpoint.

        Args:
            path (str): filepath to model checkpoint

        Note:
            ``joblib.load`` unpickles the file, which can execute arbitrary
            code. Only load checkpoints from trusted sources.
        """
        self.models = joblib.load(path)

    def save(self, path: str) -> None:
        """Save model to filepath

        Serializes the whole ``self.models`` dict (all tasks) with joblib.

        Args:
            path (str): filepath to model checkpoint
        """
        joblib.dump(self.models, path)

    def fit(self, task: str, X: np.ndarray, y: np.ndarray) -> None:
        """Train the random forest for a given task

        Args:
            task (str): task to train
            X (np.ndarray): training features
            y (np.ndarray): training labels

        Raises:
            ValueError: if ``task`` is not a known task.
        """
        self._check_task(task)
        # scikit-learn estimators do not modify their input arrays, so no
        # defensive copy of X / y is made (saves a full copy of the data).
        self.models[task].fit(X, y)

    def predict(self, task: str, X: np.ndarray) -> np.ndarray:
        """Predicts labels for a given Tox21 target using molecule features

        Args:
            task (str): the Tox21 target to predict for
            X (np.ndarray): molecule features used for prediction, one row per
                molecule; any array-like convertible to a 2D array is accepted

        Returns:
            np.ndarray: predicted probability for positive class

        Raises:
            ValueError: if ``task`` is unknown or ``X`` is not two-dimensional.
        """
        self._check_task(task)
        X = np.asarray(X)
        if X.ndim != 2:
            raise ValueError(
                f"Function expects 2D np.array. Current shape: {X.shape}"
            )
        # Column 1 of predict_proba is the second entry of the model's
        # classes_, i.e. the positive class when labels are {0, 1}.
        return self.models[task].predict_proba(X)[:, 1]