""" Script for fitting and saving any preprocessing assets, as well as the fitted RandomForest model """ import numpy as np from tabulate import tabulate from datasets import load_dataset from sklearn.metrics import roc_auc_score from data import preprocess_molecules from model import Tox21RFClassifier from utils import HF_TOKEN def get_sample_mask(removed_idxs: list[int], labels: np.ndarray): # mask out NaN labels and labels of removed idxs task_mask = ~np.isnan(labels) removed_mask = np.ones_like(labels, dtype=bool) removed_mask[removed_idxs] = 0 feature_mask = task_mask[removed_mask] label_mask = np.logical_and(task_mask, removed_mask) return feature_mask, label_mask def main(): # save preprocessing scaler and ecdf distributions save_folder = "assets/model/" ds = load_dataset("tschouis/tox21", token=HF_TOKEN) print("Preprocess train molecules") train_smiles = list(ds["train"]["smiles"]) train_features, train_removed_idxs = preprocess_molecules( train_smiles, save_ecdf_path="assets/ecdfs.pkl", save_scaler_path="assets/scaler.pkl", ) print("Preprocess validation molecules") val_smiles = list(ds["validation"]["smiles"]) val_features, val_removed_idxs = preprocess_molecules( val_smiles, load_ecdf_path="assets/ecdfs.pkl", load_scaler_path="assets/scaler.pkl", ) model = Tox21RFClassifier(seed=42) print("Start training.") for task in model.tasks: task_labels = ds["train"].to_pandas()[task].to_numpy() feature_mask, label_mask = get_sample_mask(train_removed_idxs, task_labels) print(f"Fit task {task} using {sum(label_mask)} samples") model.fit( task, train_features[feature_mask], task_labels[label_mask].astype(int) ) print(f"Save model under {save_folder}") # model.save_model(save_folder) print("Evaluate model") results = {} for task in model.tasks: task_labels = ds["validation"].to_pandas()[task].to_numpy() feature_mask, label_mask = get_sample_mask(val_removed_idxs, task_labels) pred = model.predict(task, val_features[feature_mask]) results[task] = [ roc_auc_score(y_true=task_labels[label_mask].astype(int), y_score=pred) ] print("Results:") print(tabulate(results, headers="keys")) if __name__ == "__main__": main()