grano1 commited on
Commit
683a09c
·
verified ·
1 Parent(s): e2d52cc

Delete custom_pipeline.py

Browse files
Files changed (1) hide show
  1. custom_pipeline.py +0 -32
custom_pipeline.py DELETED
@@ -1,32 +0,0 @@
1
- '''
2
- # @ Author: ASJC Team
3
- # @ Create Time: 2025-11-25 23:52:55
4
- # @ Modified by: ASJC Team
5
- # @ Modified time: 2025-11-25 23:53:08
6
- # @ Description: Custom pipeline for multi-label classification with fine-tuned SciBERT model.
7
- '''
8
-
9
- from transformers import TextClassificationPipeline
10
- import torch
11
-
12
- class ASJCMultiLabelPipeline(TextClassificationPipeline):
13
- def __init__(self, *args, **kwargs):
14
- self.threshold = kwargs.pop("threshold", None)
15
- super().__init__(*args, **kwargs)
16
-
17
- # If no explicit threshold passed → use threshold from config.json
18
- if self.threshold is None:
19
- self.threshold = getattr(self.model.config, "threshold", 0.3)
20
-
21
- def postprocess(self, model_outputs, **kwargs):
22
- scores = torch.sigmoid(torch.tensor(model_outputs["logits"])).tolist()
23
- results = []
24
-
25
- for i, score in enumerate(scores):
26
- if score >= self.threshold:
27
- label = self.model.config.id2label[str(i)]
28
- results.append({"label": label, "score": float(score)})
29
-
30
- # Sort by score descending
31
- results = sorted(results, key=lambda x: x["score"], reverse=True)
32
- return results