Upload folder using huggingface_hub
Browse files
README.md
CHANGED
|
@@ -103,8 +103,12 @@ class ASJCMultiLabelPipeline(TextClassificationPipeline):
|
|
| 103 |
def postprocess(self, model_outputs, **kwargs):
|
| 104 |
# Convert logits to probabilities using sigmoid
|
| 105 |
scores = torch.sigmoid(torch.tensor(model_outputs["logits"])).tolist()
|
| 106 |
-
results = []
|
| 107 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 108 |
for i, score in enumerate(scores):
|
| 109 |
if score >= self.threshold:
|
| 110 |
label = self.model.config.id2label[str(i)]
|
|
|
|
| 103 |
def postprocess(self, model_outputs, **kwargs):
|
| 104 |
# Convert logits to probabilities using sigmoid
|
| 105 |
scores = torch.sigmoid(torch.tensor(model_outputs["logits"])).tolist()
|
|
|
|
| 106 |
|
| 107 |
+
# Handle single-input batch: take the first element
|
| 108 |
+
if isinstance(scores[0], list):
|
| 109 |
+
scores = scores[0]
|
| 110 |
+
|
| 111 |
+
results = []
|
| 112 |
for i, score in enumerate(scores):
|
| 113 |
if score >= self.threshold:
|
| 114 |
label = self.model.config.id2label[str(i)]
|