angelperedo01 committed
Commit 0bcf758 · verified · 1 Parent(s): 68c0f61

Update app.py

Files changed (1)
1. app.py +61 -38
app.py CHANGED
@@ -9,16 +9,15 @@ import numpy as np
 # REPLACE THIS WITH YOUR UPLOADED MODEL NAME!
 MODEL_REPO = "angelperedo01/proj2"
 DATASET_NAME = "nvidia/Aegis-AI-Content-Safety-Dataset-2.0"
-MAX_SAMPLES = 200 # Limit samples for the demo so it doesn't take hours
+MAX_SAMPLES = 300 # Increased slightly since we aren't rendering the table live
 
 def get_text_and_label(example):
     """
-    Your custom logic to parse the NVIDIA dataset labels.
+    Parses the NVIDIA dataset labels.
     """
     text = example.get('prompt', '')
     label = None
 
-    # Try 'prompt_label' first
     if 'prompt_label' in example:
         raw_label = example['prompt_label']
         if isinstance(raw_label, str):
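The diff context cuts off inside the `isinstance(raw_label, str)` branch (old lines 25-32 are unchanged and therefore hidden). A minimal sketch of what that branch typically has to do; the exact string values are an assumption about the Aegis 2.0 label schema, not something this commit shows:

```python
# Hypothetical stand-in for the hidden branch of get_text_and_label.
# Assumption: string labels arrive as "safe"/"unsafe" (possibly mixed case);
# numeric labels are cast directly, mirroring the `label = int(raw_label)`
# line visible in the next hunk.
def parse_label(raw_label):
    if isinstance(raw_label, str):
        return 0 if raw_label.strip().lower() == "safe" else 1
    else:
        return int(raw_label)
```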
@@ -33,13 +32,12 @@ def get_text_and_label(example):
         else:
             label = int(raw_label)
 
-    # Default to Safe (0) if we really can't find it
     if label is None: label = 0
     return text, label
 
-def run_live_evaluation(progress=gr.Progress()):
+def run_evaluation(progress=gr.Progress()):
     # 1. Load Model & Data
-    yield "Loading Model from Hub...", "-", "-", []
+    yield "Loading Model...", "-", "-", []
 
     try:
         tokenizer = AutoTokenizer.from_pretrained(MODEL_REPO)
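Both the old `run_live_evaluation` and the new `run_evaluation` are generator functions: Gradio streams each yielded tuple into the event's output components in order, which is what drives the live status and metric updates. A self-contained sketch of that pattern (component and function names here are illustrative, not from app.py):

```python
import time
import gradio as gr

def count_up():
    # Each yield is one UI update, mapped to (status_box, value_box).
    for i in range(1, 4):
        time.sleep(1)
        yield f"Step {i} done", str(i)

with gr.Blocks() as sketch:
    status_box = gr.Label(label="Status")
    value_box = gr.Label(label="Value")
    gr.Button("Go").click(fn=count_up, inputs=None, outputs=[status_box, value_box])

sketch.queue().launch()
```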
@@ -48,81 +46,106 @@ def run_live_evaluation(progress=gr.Progress()):
         model.to(device)
         model.eval()
     except Exception as e:
-        yield f"Error loading model: {str(e)}", "Error", "Error", []
+        yield f"Error: {str(e)}", "Error", "Error", []
         return
 
-    # Load Dataset (Streaming or small slice for speed)
-    yield "Loading NVIDIA Dataset...", "-", "-", []
+    yield "Loading Dataset...", "-", "-", []
     try:
-        # Try test split, fallback to train
         ds = load_dataset(DATASET_NAME, split="test")
     except:
         ds = load_dataset(DATASET_NAME, split="train")
 
-    # Shuffle and select subset for the demo
+    # Shuffle and select subset
     ds = ds.shuffle(seed=42).select(range(MAX_SAMPLES))
 
     true_labels = []
     predictions = []
-    logs = [] # To store misclassifications
+
+    # Store full details to filter later
+    # Structure: [Status, Text, True, Pred]
+    history_correct = []
+    history_incorrect = []
 
     # 2. The Evaluation Loop
-    for i, item in enumerate(progress.tqdm(ds, desc="Evaluating...")):
+    # We yield updates less frequently to prevent UI flashing
+    for i, item in enumerate(progress.tqdm(ds, desc="Classifying...")):
         text, true_label = get_text_and_label(item)
         true_labels.append(true_label)
 
-        # Tokenize
-        inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=256).to(device)
-
         # Predict
+        inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=256).to(device)
         with torch.no_grad():
             logits = model(**inputs).logits
         pred = torch.argmax(logits, dim=-1).item()
         predictions.append(pred)
 
-        # Log Errors (If prediction is wrong)
-        if pred != true_label:
-            status = "🔴 MISS"
-            logs.insert(0, [status, text[:80] + "...", "Safe" if true_label==0 else "Unsafe", "Safe" if pred==0 else "Unsafe"])
+        # Store for final report
+        label_map = {0: "Safe", 1: "Unsafe"}
+        entry = [
+            text,
+            label_map[true_label],
+            label_map[pred]
+        ]
+
+        if pred == true_label:
+            history_correct.append(["✅ Correct"] + entry)
         else:
-            # Optional: Log successes too if you want, but it clutters the view
-            pass
+            history_incorrect.append(["🔴 WRONG"] + entry)
 
-        # Update UI every 5 steps
-        if i % 5 == 0 or i == len(ds)-1:
+        # Update metrics every 10 steps (Reduces flashing)
+        if i % 10 == 0:
             acc = accuracy_score(true_labels, predictions)
             f1 = f1_score(true_labels, predictions, zero_division=0)
-
-            status_msg = f"Processed {i+1}/{MAX_SAMPLES}"
-            yield status_msg, f"{acc:.2%}", f"{f1:.2f}", logs[:10] # Show last 10 errors
+            # Yield empty list for table so it doesn't try to render anything yet
+            yield f"Processed {i+1}/{MAX_SAMPLES}", f"{acc:.2%}", f"{f1:.2f}", []
+
+    # 3. Final Compilation
+    # Grab last 10 incorrect and last 10 correct
+    final_display_data = []
+
+    # Add header/separator logic if you want, or just mix them
+    # We prioritize showing errors first
+    if history_incorrect:
+        final_display_data.extend(history_incorrect[-10:]) # Last 10 errors
+
+    if history_correct:
+        final_display_data.extend(history_correct[-10:]) # Last 10 correct
+
+    final_acc = accuracy_score(true_labels, predictions)
+    final_f1 = f1_score(true_labels, predictions, zero_division=0)
+
+    yield "Evaluation Complete!", f"{final_acc:.2%}", f"{final_f1:.2f}", final_display_data
 
 # --- UI LAYOUT ---
 with gr.Blocks(theme=gr.themes.Soft()) as demo:
-    gr.Markdown(f"## 🛡️ Live Safety Model Evaluation")
-    gr.Markdown(f"Running `{MODEL_REPO}` on `{DATASET_NAME}` (Live Inference)")
+    gr.Markdown(f"## 🛡️ Model Safety Evaluation Dashboard")
+    gr.Markdown(f"Testing `{MODEL_REPO}` on `{DATASET_NAME}`")
 
     with gr.Row():
-        start_btn = gr.Button("▶️ Start Live Test", variant="primary")
+        start_btn = gr.Button("▶️ Run Live Test", variant="primary")
 
     with gr.Row():
         with gr.Column():
             status_box = gr.Label(value="Ready", label="Status")
         with gr.Column():
-            acc_box = gr.Label(value="-", label="Current Accuracy")
+            acc_box = gr.Label(value="-", label="Accuracy")
         with gr.Column():
-            f1_box = gr.Label(value="-", label="Current F1 Score")
+            f1_box = gr.Label(value="-", label="F1 Score")
 
-    gr.Markdown("### 🚨 Recent Misclassifications (Live Feed)")
-    log_table = gr.Dataframe(
-        headers=["Status", "Text Snippet", "True Label", "Predicted"],
+    gr.Markdown("### 📝 Final Report: Sample of Results")
+    gr.Markdown("*(Showing last 10 Incorrect and last 10 Correct predictions)*")
+
+    # Defined table but it stays empty until the end
+    result_table = gr.Dataframe(
+        headers=["Result", "Text Snippet", "True Label", "Predicted"],
         datatype=["str", "str", "str", "str"],
-        row_count=10
+        wrap=True
     )
 
     start_btn.click(
-        fn=run_live_evaluation,
+        fn=run_evaluation,
         inputs=None,
-        outputs=[status_box, acc_box, f1_box, log_table]
+        outputs=[status_box, acc_box, f1_box, result_table]
    )
 
 demo.queue().launch()
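For quick debugging outside the Space, the per-example predict step in the loop can be reproduced standalone. A sketch assuming the checkpoint is a standard sequence-classification model; the commit never shows the model-loading line, so `AutoModelForSequenceClassification` is an assumption:

```python
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification

MODEL_REPO = "angelperedo01/proj2"

tokenizer = AutoTokenizer.from_pretrained(MODEL_REPO)
model = AutoModelForSequenceClassification.from_pretrained(MODEL_REPO).eval()

# Same tokenize -> argmax path as the evaluation loop in app.py.
inputs = tokenizer("How do I bake bread?", return_tensors="pt",
                   truncation=True, max_length=256)
with torch.no_grad():
    logits = model(**inputs).logits
pred = torch.argmax(logits, dim=-1).item()
print({0: "Safe", 1: "Unsafe"}[pred])
```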
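The two dashboard metrics come straight from scikit-learn; a tiny worked example of the same calls and the same formatting the app uses:

```python
from sklearn.metrics import accuracy_score, f1_score

true_labels = [0, 1, 1, 0]
predictions = [0, 1, 0, 0]

# 3 of 4 correct -> 75.00%. F1 on the positive ("unsafe") class:
# precision = 1/1, recall = 1/2 -> F1 = 0.67
print(f"{accuracy_score(true_labels, predictions):.2%}")              # 75.00%
print(f"{f1_score(true_labels, predictions, zero_division=0):.2f}")  # 0.67
```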