"""Classify one text with a multi-label sequence-classification model.

The script loads a tokenizer and model via ``transformers``, optionally
replaces ``@mentions`` in the input, applies an independent sigmoid to every
logit and reports the labels whose probability exceeds a threshold.
"""

import argparse
import re

import numpy as np

DEFAULT_THRESHOLD = 0.5

# Upper bound on the number of tokens passed to the model; longer inputs are
# truncated by the tokenizer.
MAX_SEQUENCE_LENGTH = 8192

# Matches "@" followed by one or more word characters.
_MENTION_PATTERN = re.compile(r'@\w+')


def preprocess_text(text, anonymize_mentions=True):
    """Return *text* with every ``@mention`` replaced by a fixed placeholder.

    Args:
        text: Raw input string.
        anonymize_mentions: When false, *text* is returned unchanged.

    Returns:
        The (possibly) anonymized string.
    """
    if anonymize_mentions:
        text = _MENTION_PATTERN.sub('@anonymized_account', text)
    return text


def sigmoid(logits):
    """Return the element-wise logistic function of *logits*.

    ``exp`` is only ever evaluated on non-positive values, so large-magnitude
    logits cannot overflow (the naive ``1 / (1 + exp(-x))`` does for very
    negative ``x``).

    Args:
        logits: Array-like of real numbers.

    Returns:
        A float64 ``numpy.ndarray`` with the same shape as *logits*.
    """
    values = np.asarray(logits, dtype=np.float64)
    decay = np.exp(-np.abs(values))  # always in (0, 1]
    return np.where(values >= 0, 1.0 / (1.0 + decay), decay / (1.0 + decay))


def score_logits(logits, threshold=DEFAULT_THRESHOLD):
    """Convert the logits of a single example into probabilities and flags.

    Args:
        logits: Array-like holding the logits of one example; any shape is
            accepted and flattened (e.g. ``(1, num_labels)``).
        threshold: A label is flagged when its probability is strictly
            greater than this value.

    Returns:
        A ``(probabilities, predictions)`` tuple of 1-D arrays, one entry per
        label; ``predictions`` is boolean.
    """
    # reshape(-1) instead of squeeze(): squeeze() would turn the logits of a
    # one-label model into a 0-d array that cannot be indexed per label.
    flat = np.asarray(logits, dtype=np.float64).reshape(-1)
    probabilities = sigmoid(flat)
    return probabilities, probabilities > threshold


def main():
    """Parse the command line, run the model once and print the results."""
    parser = argparse.ArgumentParser()
    parser.add_argument("text", type=str, help="Text to classify")
    parser.add_argument(
        "--model-path",
        type=str,
        default="yazoniak/twitter-emotion-pl-classifier",
        help="Path to model or HF model ID",
    )
    parser.add_argument(
        "--threshold",
        type=float,
        default=DEFAULT_THRESHOLD,
        help="Classification threshold (default: 0.5)",
    )
    parser.add_argument(
        "--no-anonymize",
        action="store_true",
        help="Disable mention anonymization (not recommended)",
    )
    args = parser.parse_args()

    # Imported here rather than at module level so that ``--help`` and
    # argument errors do not have to wait for these libraries to load.
    import torch
    from transformers import AutoTokenizer, AutoModelForSequenceClassification

    print(f"Loading model from: {args.model_path}")
    tokenizer = AutoTokenizer.from_pretrained(args.model_path)
    model = AutoModelForSequenceClassification.from_pretrained(args.model_path)
    model.eval()

    labels = [model.config.id2label[i] for i in range(model.config.num_labels)]

    anonymize = not args.no_anonymize
    processed_text = preprocess_text(args.text, anonymize_mentions=anonymize)

    if anonymize and processed_text != args.text:
        print(f"Preprocessed text: {processed_text}")

    print(f"\nInput text: {args.text}\n")

    inputs = tokenizer(
        processed_text,
        return_tensors="pt",
        truncation=True,
        max_length=MAX_SEQUENCE_LENGTH,
    )

    with torch.no_grad():
        outputs = model(**inputs)

    probabilities, predictions = score_logits(
        outputs.logits.numpy(), args.threshold
    )

    assigned_labels = [
        label for label, is_assigned in zip(labels, predictions) if is_assigned
    ]

    if assigned_labels:
        print("Assigned Labels:")
        print("-" * 40)
        for label in assigned_labels:
            print(f" {label}")
        print()
    else:
        print("No labels assigned (all below threshold)\n")

    print("All Labels (with probabilities):")
    print("-" * 40)
    for label, probability, is_assigned in zip(labels, probabilities, predictions):
        status = "✓" if is_assigned else " "
        print(f"{status} {label:15s}: {probability:.4f}")


if __name__ == "__main__":
    main()