"""
This file provides a predict function for the Tox21 challenge. It takes a
list of SMILES strings as input and returns a nested dictionary keyed first
by SMILES and then by target name.
"""

# ---------------------------------------------------------------------------------------
# Dependencies
from typing import List
from collections import defaultdict

from data import preprocess_molecules
from model import Tox21RFClassifier

# ---------------------------------------------------------------------------------------


def predict(smiles_list: List[str]) -> dict:
    """Apply the Tox21 classifier to a list of SMILES strings.

    Args:
        smiles_list: SMILES strings to score.

    Returns:
        A ``defaultdict(dict)`` mapping each SMILES string to a
        ``{target: score}`` dictionary with one entry per task in
        ``model.tasks``. Molecules whose index is reported in the
        ``removed_idxs`` returned by ``preprocess_molecules`` receive a
        placeholder score of ``0.0`` for every target instead of a model
        prediction. If the same SMILES occurs more than once in
        ``smiles_list``, the scores from its last occurrence are kept.
    """
    # Preprocessing pipeline: featurize the molecules using the stored
    # ECDFs and scaler.
    features, removed_idxs = preprocess_molecules(
        smiles_list,
        load_ecdf_path="assets/ecdfs.pkl",
        load_scaler_path="assets/scaler.pkl",
    )
    # Build the set once so the membership check below is O(1) per
    # molecule, whatever container type preprocess_molecules returns.
    removed = set(removed_idxs)

    # Set up the model.
    model = Tox21RFClassifier(seed=42)
    model.load_model("assets/model/")

    # Make predictions.
    predictions = defaultdict(dict)
    for i, smiles in enumerate(smiles_list):
        if i in removed:
            # Molecules dropped during preprocessing get a fixed
            # placeholder score for every target.
            for target in model.tasks:
                predictions[smiles][target] = 0.0
            continue
        # NOTE(review): features is indexed by the molecule's position in
        # smiles_list, which assumes preprocess_molecules keeps one row per
        # input (including removed ones). Confirm against data.py.
        for target in model.tasks:
            predictions[smiles][target] = model.predict(target, features[i])
    return predictions