import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from transformers.utils import logging as hf_logging

MODEL_ID = "google/t5_xxl_true_nli_mixture"

hf_logging.set_verbosity_info()

tok = AutoTokenizer.from_pretrained(MODEL_ID)
model = AutoModelForSeq2SeqLM.from_pretrained(
    MODEL_ID,
    use_safetensors=True,
    low_cpu_mem_usage=True,
    torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32,
    device_map="auto",
).eval()

# Cache the token ids for "0"/"1" so scoring avoids repeated tokenizer calls.
ID0 = tok.encode("0", add_special_tokens=False)[0]
ID1 = tok.encode("1", add_special_tokens=False)[0]
DEC_START = model.config.decoder_start_token_id


# --- Endpoint 1: logits + threshold ---
def score_threshold(premise: str, hypothesis: str, threshold: str = "0.5"):
    # Parse the threshold from the textbox and clamp it to [0, 1].
    try:
        thr = float(threshold)
    except Exception:
        thr = 0.5
    thr = max(0.0, min(1.0, thr))

    text = f"premise: {premise} hypothesis: {hypothesis}"
    enc = tok(text, return_tensors="pt").to(model.device)
    dec_in = torch.tensor([[DEC_START]], device=model.device)
    with torch.no_grad():
        logits = model(**enc, decoder_input_ids=dec_in).logits[0, -1]
    p = torch.softmax(logits.float(), dim=-1)
    p1 = float(p[ID1])
    label = "1" if p1 >= thr else "0"
    return {"label": label, "p1": p1, "threshold": thr}


# --- Endpoint 2: greedy decode ---
MAX_NEW_TOKENS = 1  # the model emits a single "0"/"1" token; raise only if you expect longer outputs


def score_greedy(premise: str, hypothesis: str):
    text = f"premise: {premise} hypothesis: {hypothesis}"
    enc = tok(text, return_tensors="pt").to(model.device)
    with torch.no_grad():
        out_ids = model.generate(
            **enc,
            max_new_tokens=MAX_NEW_TOKENS,
            do_sample=False,
            num_beams=1,
        )
    # For an encoder-decoder model, generate() returns decoder tokens only
    # (decoder start, generated ids, EOS); the input prompt is never echoed,
    # so there is nothing to slice off. Slicing with the encoder input
    # length would leave an empty sequence. Just strip the special tokens.
    label = tok.decode(out_ids[0], skip_special_tokens=True).strip()

    # Safety check: fall back to the thresholded-logits path if the decode
    # produced anything other than "0" or "1".
    if label not in ("0", "1"):
        dec_in = torch.tensor([[DEC_START]], device=model.device)
        with torch.no_grad():
            logits = model(**enc, decoder_input_ids=dec_in).logits[0, -1]
        p = torch.softmax(logits.float(), dim=-1)
        p1 = float(p[ID1])
        label = "1" if p1 >= 0.5 else "0"
    return {"label": label}


# Interfaces with distinct API names
iface_thresh = gr.Interface(
    fn=score_threshold,
    inputs=[
        gr.Textbox(label="Premise"),
        gr.Textbox(label="Hypothesis"),
        gr.Textbox(label="Decision threshold (0–1)", value="0.5"),
    ],
    outputs=gr.JSON(label="Prediction"),
    title="T5-XXL TRUE NLI (logits + threshold)",
    description="Returns '1' if p1 ≥ threshold; also returns p1.",
    api_name="predict_threshold",
)

iface_greedy = gr.Interface(
    fn=score_greedy,
    inputs=[gr.Textbox(label="Premise"), gr.Textbox(label="Hypothesis")],
    outputs=gr.JSON(label="Prediction"),
    title="T5-XXL TRUE NLI (greedy decode)",
    description=f"Greedy decode with max_new_tokens={MAX_NEW_TOKENS}.",
    api_name="predict_greedy",
)

demo = gr.TabbedInterface([iface_thresh, iface_greedy], tab_names=["Threshold", "Greedy"])

if __name__ == "__main__":
    demo.queue(default_concurrency_limit=1).launch(show_error=True)
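
# --- Usage sketch (client side) ---
# A minimal sketch of calling the two named endpoints with the official
# gradio_client package, kept as comments so it does not run inside this
# script. Two assumptions not guaranteed by the code above: the app is
# serving on Gradio's default local URL (http://127.0.0.1:7860/), and
# gradio_client is installed (pip install gradio_client). The example
# premise/hypothesis strings are illustrative only. Run from a separate
# process once launch() is up; positional args follow each Interface's
# input order, and endpoints are addressed by "/" + api_name.
#
#   from gradio_client import Client
#
#   client = Client("http://127.0.0.1:7860/")  # assumed local URL
#   print(client.predict(
#       "The cat sat on the mat.",   # premise
#       "A cat is on a mat.",        # hypothesis
#       "0.5",                       # threshold (string, as in the textbox)
#       api_name="/predict_threshold",
#   ))
#   print(client.predict(
#       "The cat sat on the mat.",
#       "A cat is on a mat.",
#       api_name="/predict_greedy",
#   ))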