"""
This file provides a predict function for the Tox21 classifier.

As input it takes a list of SMILES strings and it outputs a nested dictionary
with SMILES and target names as keys.
"""

# ---------------------------------------------------------------------------------------
# Dependencies
from collections import defaultdict

import numpy as np

from src.data import create_descriptors
from src.utils import load_pickle, KNOWN_DESCR
from src.model import Tox21RFClassifier


# ---------------------------------------------------------------------------------------
def predict(smiles_list: list[str]) -> dict[str, dict[str, float]]:
    """Applies the classifier to a list of SMILES strings.

    Returns prediction=0.0 for any molecule that could not be cleaned. If the
    same SMILES string occurs more than once, it maps to a single entry in the
    result (later occurrences overwrite earlier ones).

    Args:
        smiles_list (list[str]): list of SMILES strings

    Returns:
        dict: nested prediction dictionary, following
            {'<smiles>': {'<target>': <prediction>}}. The object returned is a
            ``defaultdict(dict)``.
    """
    print(f"Received {len(smiles_list)} SMILES strings")

    predictions = defaultdict(dict)

    # Nothing to featurize or score: return the empty result right away instead
    # of pushing an empty batch through the preprocessing and model code.
    if len(smiles_list) == 0:
        return predictions

    # preprocessing pipeline
    ecdfs_path = "assets/ecdfs.pkl"
    scaler_path = "assets/scaler.pkl"

    ecdfs = load_pickle(ecdfs_path)
    scaler = load_pickle(scaler_path)
    print(f"Loaded ecdfs from {ecdfs_path}")
    print(f"Loaded scaler from {scaler_path}")

    descriptors = KNOWN_DESCR
    features, mol_mask = create_descriptors(
        smiles_list,
        ecdfs=ecdfs,
        scaler=scaler,
        descriptors=descriptors,
    )
    print(f"Created descriptors {descriptors} for molecules.")
    print(f"{len(mol_mask) - sum(mol_mask)} molecules removed during cleaning")

    # setup model
    model = Tox21RFClassifier(seed=42)
    model_path = "assets/rf_alltasks.joblib"
    model.load_model(model_path)
    print(f"Loaded model from {model_path}")

    # make predictions
    # Running count of clean molecules minus one: for a clean molecule this is
    # its row index into `features` / the per-target prediction array. For
    # molecules that failed cleaning the value is meaningless (it can be -1)
    # and is never used for indexing.
    feat_indices = np.cumsum(mol_mask) - 1

    # When every molecule failed cleaning there are no feature rows to score and
    # no prediction would be read below, so the model call is skipped entirely.
    # NOTE(review): presumably the underlying estimator rejects an empty feature
    # matrix (scikit-learn estimators do) - confirm against Tox21RFClassifier.
    has_clean_molecules = bool(np.any(mol_mask))

    for target in model.tasks:
        target_pred = (
            model.predict(target, features) if has_clean_molecules else None
        )
        for smiles, is_clean, i in zip(smiles_list, mol_mask, feat_indices):
            # Conditional expression: target_pred is only indexed for clean rows.
            predictions[smiles][target] = (
                float(target_pred[i]) if is_clean else 0.0
            )

    return predictions