import copy
import json
from typing import Any

import numpy as np
import pandas as pd
from datasets import load_dataset
from sklearn.base import BaseEstimator, TransformerMixin
from sklearn.feature_selection import VarianceThreshold
from sklearn.preprocessing import StandardScaler, FunctionTransformer
from statsmodels.distributions.empirical_distribution import ECDF
from rdkit import Chem, DataStructs
from rdkit.Chem import Descriptors, rdFingerprintGenerator, MACCSkeys
from rdkit.Chem.rdchem import Mol

from .utils import USED_200_DESCR, TOX_SMARTS_PATH, Standardizer, FeatureDictMixin


class SquashScaler(TransformerMixin, BaseEstimator):
    """
    Scaler that performs sequential standardization, a tanh nonlinearity, and
    re-standardization. Inspired by DeepTox (Mayr et al., 2016).
    """

    def __init__(self):
        self.scaler1 = StandardScaler()
        self.scaler2 = StandardScaler()

    def fit(self, X):
        _X = X.copy()
        _X = self.scaler1.fit_transform(_X)
        _X = np.tanh(_X)
        # fit only; do not overwrite _X with the fitted scaler object
        self.scaler2.fit(_X)
        self.is_fitted_ = True
        return self

    def transform(self, X):
        _X = X.copy()
        _X = self.scaler1.transform(_X)
        _X = np.tanh(_X)
        return self.scaler2.transform(_X)
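
# Usage sketch (illustrative, not part of the original module): squash-scale a
# random feature matrix. Assumes plain NumPy input.
#
#   rng = np.random.default_rng(0)
#   X = rng.normal(loc=5.0, scale=3.0, size=(100, 4))
#   squashed = SquashScaler().fit(X).transform(X)
#   # each column is ~zero-mean, unit-variance, with outliers damped by tanh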


SCALER_REGISTRY = {
    None: FunctionTransformer,
    "standard": StandardScaler,
    "squash": SquashScaler,
}


class SubSampler(TransformerMixin, BaseEstimator):
    """
    Preprocessor that randomly samples at most `max_samples` rows from the data.

    Args:
        max_samples (int): Maximum allowed samples. If -1, all samples are retained.

    Input:
        np.ndarray: A 2D NumPy array of shape (n_samples, n_features).

    Output:
        np.ndarray: Subsampled array of shape (min(n_samples, max_samples), n_features).
    """

    def __init__(self, *, max_samples=-1):
        self.max_samples = max_samples
        self.is_fitted_ = True

    def fit(self, X: np.ndarray, y: np.ndarray | None = None):
        return self

    def transform(
        self, X: np.ndarray, y: np.ndarray | None = None
    ) -> np.ndarray | tuple[np.ndarray, np.ndarray]:
        _X = X.copy()
        _y = y.copy() if y is not None else None
        if self.max_samples > 0 and _X.shape[0] > self.max_samples:
            # sample without replacement so no row appears twice in the subsample
            resample_idxs = np.random.choice(
                np.arange(_X.shape[0]), size=(self.max_samples,), replace=False
            )
            _X = _X[resample_idxs]
            _y = _y[resample_idxs] if _y is not None else None
        if _y is None:
            return _X
        return _X, _y
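
# Usage sketch (illustrative, not part of the original module):
#
#   X = np.arange(20).reshape(10, 2)
#   y = np.arange(10)
#   X_sub, y_sub = SubSampler(max_samples=4).transform(X, y)
#   # X_sub.shape == (4, 2); rows of X and y stay aligned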


class FeatureSelector(FeatureDictMixin, TransformerMixin, BaseEstimator):
    """
    Preprocessor that performs feature selection based on variance and correlation.

    This transformer selects features that:
    1. Have variance above a specified threshold.
    2. Are below a given pairwise correlation threshold.
    3. Among the remaining features, keeps only the top `max_features` with the highest variance.

    The input and output are both dictionaries mapping feature types to their corresponding
    feature matrices.

    Args:
        min_var (float): Minimum variance required for a feature to be retained.
        max_corr (float): Maximum allowed correlation between features.
            Features exceeding this threshold with others are removed.
        max_features (int): Maximum number of features to keep after filtering.
            If -1, all remaining features are retained.
        feature_keys (list[str]): Feature types to apply feature selection to.
            Each filter can be restricted separately via `<filter>__feature_keys`.
        independent_keys (bool): Apply a filter only within feature types,
            set per filter via `<filter>__independent_keys`.

    Input:
        dict[str, np.ndarray]: A dictionary where each key corresponds to a feature type
            and each value is a 2D NumPy array of shape (n_samples, n_features).

    Output:
        dict[str, np.ndarray]: A dictionary with the same keys as the input,
            containing only the selected features for each feature type.
    """

    def __init__(
        self,
        *,
        min_var=0.0,
        max_corr=1.0,
        max_features=-1,
        feature_keys=None,
        min_var__feature_keys=None,
        max_corr__feature_keys=None,
        max_features__feature_keys=None,
        min_var__independent_keys=False,
        max_corr__independent_keys=False,
        max_features__independent_keys=False,
    ):
        self.min_var = min_var
        self.max_corr = max_corr
        self.max_features = max_features
        self.min_var__feature_keys = min_var__feature_keys
        self.max_corr__feature_keys = max_corr__feature_keys
        self.max_features__feature_keys = max_features__feature_keys
        self.min_var__independent_keys = min_var__independent_keys
        self.max_corr__independent_keys = max_corr__independent_keys
        self.max_features__independent_keys = max_features__independent_keys
        super().__init__(feature_keys=feature_keys)

    def _get_min_var_mask(self, X: np.ndarray, *args) -> np.ndarray:
        var_thresh = VarianceThreshold(threshold=self.min_var)
        return var_thresh.fit(X).get_support()  # boolean keep-mask

    def _get_max_corr_mask(
        self, X: np.ndarray, prev_feature_mask: np.ndarray
    ) -> np.ndarray:
        _prev_feature_mask = prev_feature_mask.copy()
        corr_matrix = np.corrcoef(X[:, _prev_feature_mask], rowvar=False)
        upper_tri = np.triu(corr_matrix, k=1)
        # drop feature j if any earlier feature i correlates with it above
        # max_corr (signed correlation, as in the original greedy scheme)
        to_keep = ~(upper_tri > self.max_corr).any(axis=0)
        _prev_feature_mask[_prev_feature_mask] = to_keep
        return _prev_feature_mask

    def _get_max_features_mask(
        self, X: np.ndarray, prev_feature_mask: np.ndarray
    ) -> np.ndarray:
        _prev_feature_mask = prev_feature_mask.copy()
        # keep the max_features features with the highest variance
        feature_vars = np.nanvar(X[:, _prev_feature_mask], axis=0)
        order = np.argsort(feature_vars)[: -(self.max_features + 1) : -1]
        # map the subset indices from argsort back to global column indices
        keep_feat_idx = np.flatnonzero(_prev_feature_mask)[order]
        _prev_feature_mask = np.isin(
            np.arange(len(_prev_feature_mask)), keep_feat_idx, assume_unique=True
        )
        return _prev_feature_mask

    def apply_filter(self, filter, X, prev_feature_mask):
        mask = prev_feature_mask.copy()
        func = getattr(self, f"_get_{filter}_mask")
        feature_keys = getattr(self, f"{filter}__feature_keys")
        if getattr(self, f"{filter}__independent_keys"):
            for key in feature_keys:
                key_mask = self._curr_keys == key
                mask[key_mask] = func(X[:, key_mask], mask[key_mask])
        else:
            feature_key_mask = np.isin(self._curr_keys, feature_keys)
            mask[feature_key_mask] = func(
                X[:, feature_key_mask], mask[feature_key_mask]
            )
        return mask
    def fit(self, X: dict[str, np.ndarray]):
        _X = self.dict_to_array(X)
        feature_mask = np.ones((_X.shape[1],), dtype=bool)
        # drop features with less than min_var variance
        if self.min_var > 0.0:
            feature_mask = self.apply_filter("min_var", _X, feature_mask)
        # drop features correlating with another feature above max_corr
        if self.max_corr < 1.0:
            feature_mask = self.apply_filter("max_corr", _X, feature_mask)
        if self.max_features == 0:
            raise ValueError(
                f"max_features (={self.max_features}) must be -1 or larger than 0."
            )
        elif self.max_features > 0:
            # of the remaining features, keep the top max_features by variance
            feature_mask = self.apply_filter("max_features", _X, feature_mask)
        self._feature_mask = feature_mask
        self.is_fitted_ = True
        return self

    def transform(self, X: dict[str, np.ndarray]) -> dict[str, np.ndarray]:
        _X = self.dict_to_array(X)
        _X = _X[:, self._feature_mask]
        # keep the per-column key bookkeeping aligned with the reduced matrix
        self._curr_keys = self._curr_keys[self._feature_mask]
        return self.array_to_dict(_X)
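
# Usage sketch (illustrative, not part of the original module). Assumes
# FeatureDictMixin concatenates the dict values column-wise and tracks which
# key each column came from.
#
#   feats = {
#       "rdkit_descrs": np.random.rand(64, 20),
#       "ecfps": np.random.randint(0, 2, size=(64, 128)).astype(float),
#   }
#   selector = FeatureSelector(
#       min_var=1e-3,
#       max_corr=0.95,
#       min_var__feature_keys=["rdkit_descrs", "ecfps"],
#       max_corr__feature_keys=["rdkit_descrs"],
#       max_corr__independent_keys=True,
#       feature_keys=["rdkit_descrs", "ecfps"],
#   )
#   selected = selector.fit_transform(feats)  # same keys, possibly fewer columns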


class QuantileCreator(FeatureDictMixin, TransformerMixin, BaseEstimator):
    """
    Preprocessor that transforms features into empirical quantiles using ECDFs.

    This transformer applies an Empirical Cumulative Distribution Function (ECDF)
    to each feature and replaces feature values with their corresponding quantile
    ranks. The transformation is applied independently to each feature.

    Both input and output are dictionaries mapping feature types to their
    corresponding feature matrices.

    Args:
        feature_keys (list[str]): Feature types to apply quantile creation to.

    Input:
        dict[str, np.ndarray]: A dictionary where each key corresponds to a feature type
            and each value is a 2D NumPy array of shape (n_samples, n_features).

    Output:
        dict[str, np.ndarray]: A dictionary with the same keys as the input,
            where each feature value is replaced by its corresponding ECDF quantile rank.
    """

    def __init__(self, *, feature_keys=None):
        self._ecdfs = None
        super().__init__(feature_keys=feature_keys)

    def fit(self, X: dict[str, np.ndarray]):
        _X = self.dict_to_array(X)
        ecdfs = []
        # fit one ECDF per column
        for column in range(_X.shape[1]):
            raw_values = _X[:, column].reshape(-1)
            ecdfs.append(ECDF(raw_values))
        self._ecdfs = ecdfs
        self.is_fitted_ = True
        return self

    def transform(self, X: dict[str, np.ndarray]) -> dict[str, np.ndarray]:
        _X = self.dict_to_array(X)
        # quantile ranks are floats even if the raw features are integers
        quantiles = np.zeros_like(_X, dtype=float)
        for column in range(_X.shape[1]):
            raw_values = _X[:, column].reshape(-1)
            ecdf = self._ecdfs[column]
            quantiles[:, column] = ecdf(raw_values)
        return self.array_to_dict(quantiles)
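
# Usage sketch (illustrative, not part of the original module):
#
#   feats = {"rdkit_descrs": np.random.randn(100, 5)}
#   qc = QuantileCreator(feature_keys=["rdkit_descrs"])
#   quantiles = qc.fit_transform(feats)
#   # values now lie in (0, 1]: each entry is the fraction of training samples
#   # with a feature value <= the original value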


class FeaturePreprocessor(TransformerMixin, BaseEstimator):
    """This class implements the feature preprocessing for a dictionary of molecule features."""

    def __init__(
        self,
        feature_selection_config: dict[str, Any],
        feature_quantilization_config: dict[str, Any],
        descriptors: list[str],
        max_samples: int = -1,
        scaler: str = "standard",
    ):
        self.descriptors = descriptors
        self.feature_quantilization_config = copy.deepcopy(
            feature_quantilization_config
        )
        self.use_feat_quant = self.feature_quantilization_config.pop("use")
        self.quantile_creator = QuantileCreator(**self.feature_quantilization_config)
        self.feature_selection_config = copy.deepcopy(feature_selection_config)
        self.use_feat_selec = self.feature_selection_config.pop("use")
        self.feature_selection_config["feature_keys"] = descriptors
        self.feature_selector = FeatureSelector(**self.feature_selection_config)
        self.max_samples = max_samples
        self.sub_sampler = SubSampler(max_samples=max_samples)
        self.scaler = SCALER_REGISTRY[scaler]()

    def __getstate__(self):
        state = super().__getstate__()
        state["quantile_creator"] = self.quantile_creator.__getstate__()
        state["feature_selector"] = self.feature_selector.__getstate__()
        state["sub_sampler"] = self.sub_sampler.__getstate__()
        state["scaler"] = self.scaler.__getstate__()
        return state

    def __setstate__(self, state):
        _state = copy.deepcopy(state)
        self.quantile_creator.__setstate__(_state.pop("quantile_creator"))
        self.feature_selector.__setstate__(_state.pop("feature_selector"))
        self.sub_sampler.__setstate__(_state.pop("sub_sampler"))
        self.scaler.__setstate__(_state.pop("scaler"))
        super().__setstate__(_state)

    def get_state(self):
        return self.__getstate__()

    def set_state(self, state):
        return self.__setstate__(state)

    def fit(self, X: dict[str, np.ndarray]):
        """Fit the preprocessor transformers."""
        _X = copy.deepcopy(X)
        if self.use_feat_quant:
            _X = self.quantile_creator.fit_transform(_X)
        if self.use_feat_selec:
            _X = self.feature_selector.fit_transform(_X)
        _X = np.concatenate([_X[descr] for descr in self.descriptors], axis=1)
        self.scaler.fit(_X)
        return self

    def transform(
        self, X: dict[str, np.ndarray], y: np.ndarray | None = None
    ) -> np.ndarray | tuple[np.ndarray, np.ndarray]:
        _X = copy.deepcopy(X)
        _y = y.copy() if y is not None else None
        if self.use_feat_quant:
            _X = self.quantile_creator.transform(_X)
        if self.use_feat_selec:
            _X = self.feature_selector.transform(_X)
        _X = np.concatenate([_X[descr] for descr in self.descriptors], axis=1)
        _X = self.scaler.transform(_X)
        if _y is None:
            return self.sub_sampler.transform(_X)
        return self.sub_sampler.transform(_X, _y)
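
# Usage sketch (illustrative, not part of the original module). The config
# dicts below are hypothetical; the only structure visible here is a "use"
# flag plus kwargs for the respective sub-transformer.
#
#   preproc = FeaturePreprocessor(
#       feature_selection_config={"use": True, "min_var": 1e-3,
#                                 "min_var__feature_keys": ["rdkit_descrs"]},
#       feature_quantilization_config={"use": True,
#                                      "feature_keys": ["rdkit_descrs"]},
#       descriptors=["rdkit_descrs"],
#       scaler="squash",
#   )
#   feats = {"rdkit_descrs": np.random.randn(50, 10)}
#   X = preproc.fit(feats).transform(feats)  # 2D array, filtered and scaled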


def create_cleaned_mol_objects(smiles: list[str]) -> tuple[list[Mol], np.ndarray]:
    """This function creates cleaned RDKit mol objects from a list of SMILES.

    Taken from https://huggingface.co/spaces/ml-jku/mhnfs/blob/main/src/data_preprocessing/create_descriptors.py
    Modification by Antonia Ebner:
    - skip uncleanable molecules
    - return clean molecule mask

    Args:
        smiles (list[str]): list of SMILES

    Returns:
        list[Mol]: list of cleaned molecules
        np.ndarray[bool]: mask that contains False at index `i` if the molecule in `smiles` at
            index `i` could not be cleaned and was removed.
    """
    sm = Standardizer(canon_taut=True)
    clean_mol_mask = list()
    mols = list()
    for smile in smiles:
        mol = Chem.MolFromSmiles(smile)
        # unparseable SMILES yield None and are treated as uncleanable
        if mol is None:
            clean_mol_mask.append(False)
            continue
        standardized_mol, _ = sm.standardize_mol(mol)
        is_cleaned = standardized_mol is not None
        clean_mol_mask.append(is_cleaned)
        if not is_cleaned:
            continue
        # round-trip through canonical SMILES to obtain a canonical mol object
        can_mol = Chem.MolFromSmiles(Chem.MolToSmiles(standardized_mol))
        mols.append(can_mol)
    return mols, np.array(clean_mol_mask)
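
# Usage sketch (illustrative, not part of the original module):
#
#   mols, mask = create_cleaned_mol_objects(["CCO", "c1ccccc1", "not_a_smiles"])
#   # len(mols) == mask.sum(); mask marks which inputs survived cleaning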


def create_ecfp_fps(mols: list[Mol], radius=3, fpsize=2048, **kwargs) -> np.ndarray:
    """This function creates ECFP fingerprints for a list of molecules.

    Inspired by https://huggingface.co/spaces/ml-jku/mhnfs/blob/main/src/data_preprocessing/create_descriptors.py

    Args:
        mols (list[Mol]): list of molecules

    Returns:
        np.ndarray: ECFP fingerprints of molecules
    """
    # the generator is stateless across molecules, so create it once
    gen = rdFingerprintGenerator.GetMorganGenerator(
        countSimulation=True, fpSize=fpsize, radius=radius
    )
    ecfps = list()
    for mol in mols:
        fp_sparse_vec = gen.GetCountFingerprint(mol)
        fp = np.zeros((0,), np.int8)
        # ConvertToNumpyArray resizes `fp` in place to hold the dense fingerprint
        DataStructs.ConvertToNumpyArray(fp_sparse_vec, fp)
        ecfps.append(fp)
    return np.array(ecfps)
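
# Usage sketch (illustrative, not part of the original module):
#
#   mols, _ = create_cleaned_mol_objects(["CCO", "c1ccccc1"])
#   fps = create_ecfp_fps(mols, radius=3, fpsize=2048)
#   # fps.shape == (2, 2048), count-simulated Morgan fingerprints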


def create_maccs_keys(mols: list[Mol]) -> np.ndarray:
    """This function creates MACCS keys for a list of molecules.

    Args:
        mols (list[Mol]): list of molecules

    Returns:
        np.ndarray: MACCS keys of molecules
    """
    maccs = [MACCSkeys.GenMACCSKeys(x) for x in mols]
    return np.array(maccs)


def get_tox_patterns(filepath: str):
    """This retrieves the tox SMARTS patterns defined in `filepath`.

    Args:
        filepath (str): Path to a JSON file containing the tox SMARTS definitions.
    """
    # load patterns
    with open(filepath) as f:
        smarts_list = [s[1] for s in json.load(f)]
    # the code below does not handle SMARTS that mix AND and OR
    assert len([s for s in smarts_list if ("AND" in s) and ("OR" in s)]) == 0
    # Chem.MolFromSmarts takes a long time, so it pays off to parse all the SMARTS
    # first and then reuse them for all molecules. This gives a huge speedup over
    # existing code. Each entry holds a list of patterns, whether to negate each
    # match result, and how to join the results into one boolean value.
    all_patterns = []
    for smarts in smarts_list:
        patterns = []  # list of SMARTS patterns
        # one flag per pattern above; negates the match result later
        negations = []
        if " AND " in smarts:
            smarts = smarts.split(" AND ")
            merge_any = False  # if an ' AND ' is found, all 'subsmarts' have to match
        else:
            # If an ' OR ' is present, it's enough if any of the 'subsmarts' match.
            # This branch also covers smarts where neither ' OR ' nor ' AND ' occur.
            smarts = smarts.split(" OR ")
            merge_any = True
        # for all subsmarts, check if they are preceded by 'NOT '
        for s in smarts:
            neg = s.startswith("NOT ")
            if neg:
                s = s[4:]
            patterns.append(Chem.MolFromSmarts(s))
            negations.append(neg)
        all_patterns.append((patterns, negations, merge_any))
    return all_patterns


def create_tox_features(mols: list[Mol], patterns: list) -> np.ndarray:
    """Matches the tox patterns against each molecule. Returns a boolean array."""
    tox_data = []
    for mol in mols:
        mol_features = []
        for patts, negations, merge_any in patterns:
            matches = [mol.HasSubstructMatch(p) for p in patts]
            # XOR: flip the match result where the pattern was prefixed with 'NOT'
            matches = [m != n for m, n in zip(matches, negations)]
            pres = any(matches) if merge_any else all(matches)
            mol_features.append(pres)
        tox_data.append(np.array(mol_features))
    return np.array(tox_data)
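
# Usage sketch (illustrative, not part of the original module). TOX_SMARTS_PATH
# comes from .utils and points at the JSON file of SMARTS definitions.
#
#   patterns = get_tox_patterns(TOX_SMARTS_PATH)
#   mols, _ = create_cleaned_mol_objects(["CCO", "c1ccccc1"])
#   tox = create_tox_features(mols, patterns)
#   # tox.shape == (2, len(patterns)), one boolean per pattern per molecule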


def create_rdkit_descriptors(mols: list[Mol]) -> np.ndarray:
    """This function creates RDKit descriptors for a list of molecules.

    Taken from https://huggingface.co/spaces/ml-jku/mhnfs/blob/main/src/data_preprocessing/create_descriptors.py

    Args:
        mols (list[Mol]): list of molecules

    Returns:
        np.ndarray: RDKit descriptors of molecules
    """
    rdkit_descriptors = list()
    for mol in mols:
        descrs = []
        for _, descr_calc_fn in Descriptors._descList:
            descrs.append(descr_calc_fn(mol))
        descrs = np.array(descrs)
        # keep only the 200 descriptors used by the original pipeline
        descrs = descrs[USED_200_DESCR]
        rdkit_descriptors.append(descrs)
    return np.array(rdkit_descriptors)


def create_quantiles(raw_features: np.ndarray, ecdfs: list) -> np.ndarray:
    """Create quantile values for the given features using one ECDF per column.

    Taken from https://huggingface.co/spaces/ml-jku/mhnfs/blob/main/src/data_preprocessing/create_descriptors.py

    Args:
        raw_features (np.ndarray): values to put into quantiles
        ecdfs (list): ECDFs to use

    Returns:
        np.ndarray: computed quantiles
    """
    # quantile ranks are floats even if the raw features are integers
    quantiles = np.zeros_like(raw_features, dtype=float)
    for column in range(raw_features.shape[1]):
        raw_values = raw_features[:, column].reshape(-1)
        ecdf = ecdfs[column]
        quantiles[:, column] = ecdf(raw_values)
    return quantiles


def fill(features, mask, value=np.nan):
    """Expand `features` to one row per molecule, writing `value` where `mask` is True.

    `mask` marks the molecules for which no features exist (e.g. uncleanable
    SMILES); the rows of `features` are placed at the remaining positions.
    """
    n_mols = len(mask)
    n_features = features.shape[1]
    data = np.full((n_mols, n_features), value, dtype=float)
    data[~mask] = features
    return data


def create_descriptors(
    smiles,
    descriptors,
    **ecfp_kwargs,
):
    """Generate molecular descriptors for multiple SMILES strings.

    Inspired by https://huggingface.co/spaces/ml-jku/mhnfs/blob/main/src/data_preprocessing/create_descriptors.py

    Each SMILES is processed and sanitized using RDKit.
    SMILES that cannot be sanitized are encoded with NaNs, and a corresponding boolean mask
    is returned to indicate which inputs were successfully processed.

    Args:
        smiles (list[str]): List of SMILES strings for which to generate descriptors.
        descriptors (list[str]): List of descriptor types to compute.
            Supported values include:
            ['ecfps', 'tox', 'maccs', 'rdkit_descrs'].

    Returns:
        tuple[dict[str, np.ndarray], np.ndarray]:
            - A dictionary mapping descriptor names to their computed arrays.
            - A boolean mask of shape (len(smiles),) indicating which SMILES
              were successfully sanitized and processed.
    """
    # Create cleaned RDKit mol objects
    mols, clean_mol_mask = create_cleaned_mol_objects(smiles)
    print(f"Cleaned molecules, {(~clean_mol_mask).sum()} could not be sanitized")
    # Create fingerprints and descriptors, collecting them by name instead of
    # looking them up via vars(), which is fragile inside a function
    computed = {}
    if "ecfps" in descriptors:
        ecfps = create_ecfp_fps(mols, **ecfp_kwargs)
        computed["ecfps"] = fill(ecfps, ~clean_mol_mask)
        print("Created ECFP fingerprints")
    if "tox" in descriptors:
        tox_patterns = get_tox_patterns(TOX_SMARTS_PATH)
        tox = create_tox_features(mols, tox_patterns)
        computed["tox"] = fill(tox, ~clean_mol_mask)
        print("Created Tox features")
    if "maccs" in descriptors:
        maccs = create_maccs_keys(mols)
        computed["maccs"] = fill(maccs, ~clean_mol_mask)
        print("Created MACCS keys")
    if "rdkit_descrs" in descriptors:
        rdkit_descrs = create_rdkit_descriptors(mols)
        computed["rdkit_descrs"] = fill(rdkit_descrs, ~clean_mol_mask)
        print("Created RDKit descriptors")
    # order the features as requested by the caller
    features = {descr: computed[descr] for descr in descriptors}
    return features, clean_mol_mask
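
# Usage sketch (illustrative, not part of the original module):
#
#   features, mask = create_descriptors(
#       ["CCO", "c1ccccc1"], ["ecfps", "rdkit_descrs"], radius=3, fpsize=2048
#   )
#   # features["ecfps"].shape == (2, 2048); rows for unsanitizable SMILES are NaN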


def get_tox21_split(token, cvfold=None):
    """Retrieve Tox21 splits from HuggingFace with respect to a given cvfold."""
    ds = load_dataset("ml-jku/tox21", token=token)
    train_df = ds["train"].to_pandas()
    val_df = ds["validation"].to_pandas()
    if cvfold is None:
        return {"train": train_df, "validation": val_df}
    combined_df = pd.concat([train_df, val_df], ignore_index=True)
    # create new splits
    cvfold = float(cvfold)
    train_df = combined_df[combined_df.CVfold != cvfold]
    val_df = combined_df[combined_df.CVfold == cvfold]
    # exclude train mols that occur in the validation split
    val_inchikeys = set(val_df["inchikey"])
    train_df = train_df[~train_df["inchikey"].isin(val_inchikeys)]
    return {
        "train": train_df.reset_index(drop=True),
        "validation": val_df.reset_index(drop=True),
    }
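
# Usage sketch (illustrative, not part of the original module). Requires a
# HuggingFace token with access to the ml-jku/tox21 dataset.
#
#   splits = get_tox21_split(token="hf_...", cvfold=0)
#   train_df, val_df = splits["train"], splits["validation"]
#   # the resulting train/validation splits are disjoint by InChIKey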