"""CLI for multi-label emotion classification of (Polish) tweets.

Loads a sequence-classification model, applies per-label temperature
scaling and per-label decision thresholds from a
``calibration_artifacts.json`` file, and prints both the assigned labels
and the calibrated probability for every label.
"""

import argparse
import json
import os
import re

import numpy as np
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer


def preprocess_text(text, anonymize_mentions=True):
    """Return *text* with every @-mention replaced by a fixed placeholder.

    NOTE(review): the model was presumably trained on anonymized tweets,
    so inference input should be anonymized the same way — confirm
    against the training pipeline.

    Args:
        text: Raw input text.
        anonymize_mentions: When True (default), replace ``@handle``
            tokens with ``@anonymized_account``.

    Returns:
        The (possibly rewritten) text.
    """
    if anonymize_mentions:
        text = re.sub(r'@\w+', '@anonymized_account', text)
    return text


def load_calibration_artifacts(calib_path):
    """Load calibration data (per-label temperatures and thresholds).

    Args:
        calib_path: Path to a ``calibration_artifacts.json`` file.

    Returns:
        The parsed JSON object; callers expect the keys
        ``"temperatures"`` and ``"optimal_thresholds"``, each mapping
        label name -> float.

    Raises:
        FileNotFoundError: If *calib_path* does not exist.
    """
    if not os.path.exists(calib_path):
        raise FileNotFoundError(f"Calibration artifacts not found at: {calib_path}")
    # Explicit encoding so the artifacts parse identically on every platform.
    with open(calib_path, 'r', encoding='utf-8') as f:
        return json.load(f)


def _resolve_calibration_path(args):
    """Pick the calibration file: explicit flag > model dir > script dir."""
    if args.calibration_path:
        return args.calibration_path
    if os.path.isdir(args.model_path):
        return os.path.join(args.model_path, "calibration_artifacts.json")
    return os.path.join(os.path.dirname(__file__), "calibration_artifacts.json")


def main():
    """Parse CLI arguments, run the model, and print calibrated predictions."""
    parser = argparse.ArgumentParser()
    parser.add_argument("text", type=str, help="Text to classify")
    parser.add_argument("--model-path", type=str,
                        default="yazoniak/twitter-emotion-pl-classifier",
                        help="Path to model or HF model ID")
    parser.add_argument("--calibration-path", type=str, default=None,
                        help="Path to calibration_artifacts.json (default: auto-detect)")
    parser.add_argument("--no-anonymize", action="store_true",
                        help="Disable mention anonymization (not recommended)")
    args = parser.parse_args()

    print(f"Loading model from: {args.model_path}")
    tokenizer = AutoTokenizer.from_pretrained(args.model_path)
    model = AutoModelForSequenceClassification.from_pretrained(args.model_path)
    model.eval()

    # Label names in id order, taken from the model config.
    labels = [model.config.id2label[i] for i in range(model.config.num_labels)]

    anonymize = not args.no_anonymize
    processed_text = preprocess_text(args.text, anonymize_mentions=anonymize)
    if anonymize and processed_text != args.text:
        print(f"Preprocessed text: {processed_text}")

    calib_path = _resolve_calibration_path(args)
    print(f"Loading calibration from: {calib_path}")
    calib_artifacts = load_calibration_artifacts(calib_path)
    temperatures = calib_artifacts["temperatures"]
    optimal_thresholds = calib_artifacts["optimal_thresholds"]

    # BUG FIX: in the original this f-string was split across a physical
    # line break, producing a SyntaxError (unterminated string literal).
    print(f"\nInput text: {args.text}\n")

    inputs = tokenizer(processed_text, return_tensors="pt",
                       truncation=True, max_length=8192)
    with torch.no_grad():
        outputs = model(**inputs)

    # Index the batch dimension explicitly instead of .squeeze():
    # squeeze() would also collapse a size-1 label dimension, leaving a
    # 0-d array that breaks the per-label indexing below.
    logits = outputs.logits[0].numpy()

    # Temperature-scale each logit, sigmoid it, and compare against the
    # per-label optimal threshold (independent multi-label decisions).
    calibrated_probs = []
    predictions = []
    for i, label in enumerate(labels):
        temp = temperatures[label]
        threshold = optimal_thresholds[label]
        calibrated_logit = logits[i] / temp
        prob = 1 / (1 + np.exp(-calibrated_logit))
        calibrated_probs.append(prob)
        predictions.append(prob > threshold)

    assigned_labels = [labels[i] for i in range(len(labels)) if predictions[i]]

    if assigned_labels:
        print("Assigned Labels (Calibrated):")
        print("-" * 40)
        for label in assigned_labels:
            print(f" {label}")
        print()
    else:
        print("No labels assigned (all below optimal thresholds)\n")

    print("All Labels (with calibrated probabilities):")
    print("-" * 40)
    for i, label in enumerate(labels):
        status = "✓" if predictions[i] else " "
        threshold = optimal_thresholds[label]
        print(f"{status} {label:15s}: {calibrated_probs[i]:.4f} (threshold: {threshold:.2f})")


if __name__ == "__main__":
    main()