grano1 commited on
Commit
562f69d
·
verified ·
1 Parent(s): af63e06

Upload folder using huggingface_hub

Browse files
Files changed (1) hide show
  1. README.md +5 -1
README.md CHANGED
@@ -103,8 +103,12 @@ class ASJCMultiLabelPipeline(TextClassificationPipeline):
103
  def postprocess(self, model_outputs, **kwargs):
104
  # Convert logits to probabilities using sigmoid
105
  scores = torch.sigmoid(torch.tensor(model_outputs["logits"])).tolist()
106
- results = []
107
 
 
 
 
 
 
108
  for i, score in enumerate(scores):
109
  if score >= self.threshold:
110
  label = self.model.config.id2label[str(i)]
 
103
  def postprocess(self, model_outputs, **kwargs):
104
  # Convert logits to probabilities using sigmoid
105
  scores = torch.sigmoid(torch.tensor(model_outputs["logits"])).tolist()
 
106
 
107
+ # Handle single-input batch: take the first element
108
+ if isinstance(scores[0], list):
109
+ scores = scores[0]
110
+
111
+ results = []
112
  for i, score in enumerate(scores):
113
  if score >= self.threshold:
114
  label = self.model.config.id2label[str(i)]