Spaces:
Sleeping
Sleeping
Commit
·
75c7791
1
Parent(s):
117adda
debug predict.py
Browse files- predict.py +12 -7
predict.py
CHANGED
|
@@ -24,12 +24,14 @@ def predict(smiles_list: list[str]) -> dict:
|
|
| 24 |
Returns:
|
| 25 |
dict: nested prediction dictionary, following {'<smiles>': {'<target>': <pred>}}
|
| 26 |
"""
|
|
|
|
| 27 |
# preprocessing pipeline
|
| 28 |
features, removed_idxs = preprocess_molecules(
|
| 29 |
smiles_list,
|
| 30 |
load_ecdf_path="assets/ecdfs.pkl",
|
| 31 |
load_scaler_path="assets/scaler.pkl",
|
| 32 |
)
|
|
|
|
| 33 |
|
| 34 |
# setup model
|
| 35 |
model = Tox21RFClassifier(seed=42)
|
|
@@ -37,13 +39,16 @@ def predict(smiles_list: list[str]) -> dict:
|
|
| 37 |
|
| 38 |
# make predictions
|
| 39 |
predictions = defaultdict(dict)
|
|
|
|
|
|
|
|
|
|
| 40 |
|
| 41 |
-
for
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
|
| 46 |
-
|
| 47 |
-
|
| 48 |
|
| 49 |
return predictions
|
|
|
|
| 24 |
Returns:
|
| 25 |
dict: nested prediction dictionary, following {'<smiles>': {'<target>': <pred>}}
|
| 26 |
"""
|
| 27 |
+
print(f"Received {len(smiles_list)} SMILES strings")
|
| 28 |
# preprocessing pipeline
|
| 29 |
features, removed_idxs = preprocess_molecules(
|
| 30 |
smiles_list,
|
| 31 |
load_ecdf_path="assets/ecdfs.pkl",
|
| 32 |
load_scaler_path="assets/scaler.pkl",
|
| 33 |
)
|
| 34 |
+
print(f"{len(removed_idxs)} molecules removed during cleaning")
|
| 35 |
|
| 36 |
# setup model
|
| 37 |
model = Tox21RFClassifier(seed=42)
|
|
|
|
| 39 |
|
| 40 |
# make predictions
|
| 41 |
predictions = defaultdict(dict)
|
| 42 |
+
# make smiles list with same num_samples as features
|
| 43 |
+
clean_smiles = [smi for i, smi in enumerate(smiles_list) if i not in removed_idxs]
|
| 44 |
+
no_pred_smiles = [smi for i, smi in enumerate(smiles_list) if i in removed_idxs]
|
| 45 |
|
| 46 |
+
for target in model.tasks:
|
| 47 |
+
target_pred = model.predict(target, features)
|
| 48 |
+
for i, smiles in enumerate(clean_smiles):
|
| 49 |
+
predictions[smiles][target] = target_pred[i]
|
| 50 |
+
|
| 51 |
+
for smiles in no_pred_smiles:
|
| 52 |
+
predictions[smiles][target] = 0.0
|
| 53 |
|
| 54 |
return predictions
|