0xnu committed on
Commit 44240fa · verified · 1 Parent(s): 9a2525c

Upload 2 files

Files changed (2)
  1. app.py +558 -0
  2. config.py +192 -0
app.py ADDED
@@ -0,0 +1,558 @@
+ import gradio as gr
+ import pandas as pd
+ from datasets import load_dataset
+ import time
+ from typing import Dict, Tuple
+ from config import ModelManager
+
+ class MathsBenchmarkApp:
+     def __init__(self):
+         """Initialise the Mathematics Benchmark application."""
+         self.dataset = None
+         self.df = None
+         self.model_manager = ModelManager()
+         self.load_dataset()
+
+     def load_dataset(self) -> None:
+         """Load the MathsBench dataset from HuggingFace."""
+         try:
+             self.dataset = load_dataset("0xnu/maths_bench", split="train")
+             self.df = pd.DataFrame(self.dataset)
+             print(f"Dataset loaded successfully: {len(self.df)} questions")
+         except Exception as e:
+             print(f"Error loading dataset: {e}")
+             self.df = pd.DataFrame()
+
+     def setup_api_provider(self, provider_name: str, api_key: str) -> Tuple[bool, str]:
+         """Set up an API provider with the supplied key."""
+         return self.model_manager.setup_provider(provider_name, api_key)
+
+     def get_filtered_data(self, category: str = "All", difficulty: str = "All") -> pd.DataFrame:
+         """Filter the dataset by category and difficulty."""
+         if self.df.empty:
+             return pd.DataFrame()
+
+         filtered_df = self.df.copy()
+
+         if category != "All":
+             filtered_df = filtered_df[filtered_df['category'] == category]
+
+         if difficulty != "All":
+             filtered_df = filtered_df[filtered_df['difficulty'] == difficulty]
+
+         return filtered_df
+
+     def create_prompt_for_question(self, question_data: Dict) -> str:
+         """Create a structured prompt for the model."""
+         prompt = f"""You are an expert mathematician. Solve this question and select the correct answer from the given options.
+
+ Question: {question_data['question']}
+
+ Available options:
+ A) {question_data['option_a']}
+ B) {question_data['option_b']}
+ C) {question_data['option_c']}
+ D) {question_data['option_d']}
+
+ Instructions:
+ 1. Work through the problem step by step
+ 2. Compare your result with each option
+ 3. Select the option that matches your calculated answer
+ 4. Respond with only the letter of your chosen answer
+
+ Your response must end with: "My final answer is: [LETTER]"
+
+ Example format:
+ First I'll solve... [working]
+ Checking the options...
+ My final answer is: B"""
+
+         return prompt
+
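+     # The prompt's closing sentinel ("My final answer is: [LETTER]") is the
+     # first pattern extract_answer_from_response() checks, so the prompt and
+     # the parser must stay in sync if either format changes.
+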
+     def evaluate_single_question(self, question_id: int, model: str) -> Dict:
+         """Evaluate a single question using the specified model."""
+         if not self.model_manager.get_configured_providers():
+             return {"error": "No API providers configured"}
+
+         question_data = self.df[self.df['question_id'] == question_id].iloc[0].to_dict()
+         prompt = self.create_prompt_for_question(question_data)
+
+         try:
+             ai_response = self.model_manager.generate_response(prompt, model, max_tokens=800)
+
+             # Parse the response to extract the answer
+             ai_answer = self.extract_answer_from_response(ai_response)
+
+             # Convert correct answer to letter format if needed
+             correct_answer_letter = self.convert_answer_to_letter(question_data)
+
+             is_correct = ai_answer == correct_answer_letter
+
+             return {
+                 "question_id": question_id,
+                 "question": question_data['question'],
+                 "category": question_data['category'],
+                 "difficulty": question_data['difficulty'],
+                 "correct_answer": question_data['correct_answer'],
+                 "correct_answer_letter": correct_answer_letter,
+                 "ai_answer": ai_answer,
+                 "is_correct": is_correct,
+                 "ai_response": ai_response,
+                 "model": model,
+                 "options": {
+                     "A": question_data['option_a'],
+                     "B": question_data['option_b'],
+                     "C": question_data['option_c'],
+                     "D": question_data['option_d']
+                 }
+             }
+         except Exception as e:
+             return {"error": f"API call failed: {str(e)}"}
+
+     def convert_answer_to_letter(self, question_data: Dict) -> str:
+         """Convert the correct answer to its corresponding letter option."""
+         correct_answer = str(question_data['correct_answer']).strip()
+
+         options = {
+             'A': str(question_data['option_a']).strip(),
+             'B': str(question_data['option_b']).strip(),
+             'C': str(question_data['option_c']).strip(),
+             'D': str(question_data['option_d']).strip()
+         }
+
+         # Find which option matches the correct answer
+         for letter, option_value in options.items():
+             if correct_answer == option_value:
+                 return letter
+
+         # If no exact match, try case-insensitive comparison
+         correct_lower = correct_answer.lower()
+         for letter, option_value in options.items():
+             if correct_lower == option_value.lower():
+                 return letter
+
+         # If still no match, return the first option as fallback
+         return 'A'
+
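+     # Caution: the 'A' fallback above means a correct_answer that matches no
+     # option is silently scored as if option A were correct, so malformed
+     # dataset rows can distort accuracy rather than raise an error.
+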
+     def extract_answer_from_response(self, response: str) -> str:
+         """Extract the letter answer from AI response."""
+         response_upper = response.upper()
+
+         # Primary method: Look for "MY FINAL ANSWER IS: X" pattern
+         if "MY FINAL ANSWER IS:" in response_upper:
+             answer_part = response_upper.split("MY FINAL ANSWER IS:")[1].strip()
+             for letter in ['A', 'B', 'C', 'D']:
+                 if letter in answer_part[:3]:  # Check first 3 chars after the phrase
+                     return letter
+
+         # Secondary method: Look for "ANSWER:" pattern
+         if "ANSWER:" in response_upper:
+             answer_part = response_upper.split("ANSWER:")[1].strip()
+             for letter in ['A', 'B', 'C', 'D']:
+                 if letter in answer_part[:10]:
+                     return letter
+
+         # Tertiary method: Look for explicit statements like "THE ANSWER IS A"
+         for letter in ['A', 'B', 'C', 'D']:
+             patterns = [
+                 f"THE ANSWER IS {letter}",
+                 f"ANSWER IS {letter}",
+                 f"I CHOOSE {letter}",
+                 f"SELECT {letter}",
+                 f"OPTION {letter}"
+             ]
+             for pattern in patterns:
+                 if pattern in response_upper:
+                     return letter
+
+         # Final fallback: collect which letters appear at all and return the
+         # last one in A-D order (the alphabetically highest letter present)
+         letters_found = []
+         for letter in ['A', 'B', 'C', 'D']:
+             if f" {letter}" in response_upper or f"{letter})" in response_upper or f"({letter}" in response_upper:
+                 letters_found.append(letter)
+
+         if letters_found:
+             return letters_found[-1]  # Alphabetically last letter found
+
+         return "Unknown"
+
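+     # Parsing precedence: the explicit sentinel wins, then "ANSWER:", then
+     # common phrasings, then the bare-letter scan; "Unknown" is never a valid
+     # option letter, so unparseable responses are always scored incorrect.
+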
+     def run_benchmark(self, category: str, difficulty: str, num_questions: int, model: str, progress=gr.Progress()) -> Tuple[pd.DataFrame, str]:
+         """Run benchmark evaluation on filtered questions."""
+         if not self.model_manager.get_configured_providers():
+             return pd.DataFrame(), "Please configure API providers first"
+
+         filtered_df = self.get_filtered_data(category, difficulty)
+
+         if filtered_df.empty:
+             return pd.DataFrame(), "No questions found for the selected filters"
+
+         # Sample questions if the requested number is less than available
+         if num_questions < len(filtered_df):
+             filtered_df = filtered_df.sample(n=num_questions, random_state=42)
+
+         results = []
+         correct_count = 0
+
+         progress(0, desc="Starting evaluation...")
+
+         for i, (_, row) in enumerate(filtered_df.iterrows()):
+             progress((i + 1) / len(filtered_df), desc=f"Evaluating question {i + 1}/{len(filtered_df)}")
+
+             result = self.evaluate_single_question(row['question_id'], model)
+
+             if "error" not in result:
+                 results.append(result)
+                 if result['is_correct']:
+                     correct_count += 1
+
+             # Add a small delay to avoid rate limits
+             time.sleep(0.5)
+
+         if not results:
+             return pd.DataFrame(), "No valid results obtained"
+
+         results_df = pd.DataFrame(results)
+         accuracy = (correct_count / len(results)) * 100
+
+         summary = f"""
+ Benchmark Complete!
+
+ Total Questions: {len(results)}
+ Correct Answers: {correct_count}
+ Accuracy: {accuracy:.2f}%
+ Model: {model}
+ Category: {category}
+ Difficulty: {difficulty}
+ """
+
+         return results_df, summary
+
+ # Global app instance
+ app = MathsBenchmarkApp()
+
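+ # Note: the module-level MathsBenchmarkApp() above downloads the dataset at
+ # import time, so a slow network delays app startup rather than a button click.
+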
+ def create_gradio_interface():
+     """Create the Gradio interface for the Mathematics Benchmark."""
+
+     # Get unique categories and difficulties
+     categories = (["All"] + sorted(app.df['category'].unique().tolist())) if not app.df.empty else ["All"]
+     difficulties = (["All"] + sorted(app.df['difficulty'].unique().tolist())) if not app.df.empty else ["All"]
+
+     with gr.Blocks(title="Mathematics Benchmark", theme=gr.themes.Soft()) as interface:
+         gr.HTML("""
+         <div style="text-align: center; padding: 20px;">
+             <h1>🧮 LLM Mathematics Benchmark</h1>
+             <p>Evaluate Large Language Models on mathematical reasoning tasks using a diverse dataset of questions</p>
+         </div>
+         """)
+
+         with gr.Tab("🔧 Configuration"):
+             gr.HTML("<h3>API Configuration</h3><p>Configure your API keys for different model providers:</p>")
+
+             # OpenAI configuration
+             with gr.Group():
+                 gr.HTML("<h4>🤖 OpenAI Configuration</h4>")
+                 with gr.Row():
+                     openai_key_input = gr.Textbox(
+                         label="OpenAI API Key",
+                         placeholder="Enter your OpenAI API key",
+                         type="password",
+                         scale=3
+                     )
+                     openai_setup_btn = gr.Button("Configure OpenAI", variant="primary", scale=1)
+
+                 openai_status = gr.Textbox(label="OpenAI Status", interactive=False)
+
+             # Claude configuration
+             with gr.Group():
+                 gr.HTML("<h4>🧠 Anthropic Claude Configuration</h4>")
+                 with gr.Row():
+                     claude_key_input = gr.Textbox(
+                         label="Anthropic API Key",
+                         placeholder="Enter your Anthropic API key",
+                         type="password",
+                         scale=3
+                     )
+                     claude_setup_btn = gr.Button("Configure Claude", variant="primary", scale=1)
+
+                 claude_status = gr.Textbox(label="Claude Status", interactive=False)
+
+             # Configuration status
+             config_summary = gr.Textbox(
+                 label="Configuration Summary",
+                 placeholder="No providers configured",
+                 interactive=False
+             )
+
+             def update_config_summary():
+                 configured = app.model_manager.get_configured_providers()
+                 if not configured:
+                     return "No providers configured"
+                 return f"Configured providers: {', '.join(configured)}"
+
+             # Return the summary alongside the status so config_summary
+             # actually updates when a provider is configured
+             def setup_openai(api_key):
+                 success, message = app.setup_api_provider("openai", api_key)
+                 return message, update_config_summary()
+
+             def setup_claude(api_key):
+                 success, message = app.setup_api_provider("claude", api_key)
+                 return message, update_config_summary()
+
+             openai_setup_btn.click(
+                 fn=setup_openai,
+                 inputs=[openai_key_input],
+                 outputs=[openai_status, config_summary]
+             )
+
+             claude_setup_btn.click(
+                 fn=setup_claude,
+                 inputs=[claude_key_input],
+                 outputs=[claude_status, config_summary]
+             )
+
+         with gr.Tab("📊 Dataset Explorer"):
+             with gr.Row():
+                 filter_category = gr.Dropdown(
+                     choices=categories,
+                     value="All",
+                     label="Category",
+                     scale=1
+                 )
+                 filter_difficulty = gr.Dropdown(
+                     choices=difficulties,
+                     value="All",
+                     label="Difficulty",
+                     scale=1
+                 )
+                 refresh_btn = gr.Button("Refresh Data", scale=1)
+
+             dataset_table = gr.Dataframe(
+                 headers=["question_id", "category", "difficulty", "question", "correct_answer"],
+                 label="Filtered Dataset"
+             )
+
+             def update_table(category, difficulty):
+                 filtered_df = app.get_filtered_data(category, difficulty)
+                 if filtered_df.empty:
+                     return pd.DataFrame()
+                 return filtered_df[['question_id', 'category', 'difficulty', 'question', 'correct_answer']]
+
+             refresh_btn.click(
+                 fn=update_table,
+                 inputs=[filter_category, filter_difficulty],
+                 outputs=[dataset_table]
+             )
+
+             # Initial load
+             interface.load(
+                 fn=update_table,
+                 inputs=[filter_category, filter_difficulty],
+                 outputs=[dataset_table]
+             )
+
+         with gr.Tab("🧪 Run Benchmark"):
+             with gr.Row():
+                 bench_category = gr.Dropdown(
+                     choices=categories,
+                     value="All",
+                     label="Category Filter"
+                 )
+                 bench_difficulty = gr.Dropdown(
+                     choices=difficulties,
+                     value="All",
+                     label="Difficulty Filter"
+                 )
+
+             with gr.Row():
+                 num_questions = gr.Slider(
+                     minimum=1,
+                     maximum=100,
+                     value=10,
+                     step=1,
+                     label="Number of Questions"
+                 )
+                 model_choice = gr.Dropdown(
+                     choices=app.model_manager.get_flat_model_list(),
+                     value=app.model_manager.get_flat_model_list()[0] if app.model_manager.get_flat_model_list() else None,
+                     label="Model"
+                 )
+
+             run_benchmark_btn = gr.Button("Run Benchmark", variant="primary", size="lg")
+
+             benchmark_summary = gr.Textbox(
+                 label="Benchmark Results Summary",
+                 lines=8,
+                 interactive=False
+             )
+
+             # Headers match the columns produced by run_benchmark_wrapper below
+             results_table = gr.Dataframe(
+                 label="Detailed Results",
+                 headers=["question_id", "question", "category", "difficulty", "correct_answer", "correct_answer_letter", "ai_answer", "ai_choice", "is_correct"]
+             )
+
+             def run_benchmark_wrapper(category, difficulty, num_q, model):
+                 results_df, summary = app.run_benchmark(category, difficulty, num_q, model)
+
+                 if results_df.empty:
+                     return summary, pd.DataFrame()
+
+                 # Prepare display dataframe
+                 display_df = results_df[['question_id', 'question', 'category', 'difficulty', 'correct_answer', 'correct_answer_letter', 'ai_answer', 'is_correct']].copy()
+
+                 # Add the actual AI choice text
+                 display_df['ai_choice'] = display_df.apply(
+                     lambda row: results_df[results_df['question_id'] == row['question_id']]['options'].iloc[0].get(row['ai_answer'], 'Unknown')
+                     if row['ai_answer'] in ['A', 'B', 'C', 'D'] else 'Invalid', axis=1
+                 )
+
+                 # Reorder columns for better display
+                 display_df = display_df[['question_id', 'question', 'category', 'difficulty', 'correct_answer', 'correct_answer_letter', 'ai_answer', 'ai_choice', 'is_correct']]
+
+                 return summary, display_df
+
+             run_benchmark_btn.click(
+                 fn=run_benchmark_wrapper,
+                 inputs=[bench_category, bench_difficulty, num_questions, model_choice],
+                 outputs=[benchmark_summary, results_table]
+             )
+
+         with gr.Tab("🔍 Debug Single Question"):
+             with gr.Row():
+                 debug_question_id = gr.Number(
+                     label="Question ID",
+                     value=450,
+                     precision=0
+                 )
+                 debug_model = gr.Dropdown(
+                     choices=app.model_manager.get_flat_model_list(),
+                     value=app.model_manager.get_flat_model_list()[0] if app.model_manager.get_flat_model_list() else None,
+                     label="Model"
+                 )
+                 debug_btn = gr.Button("Test Single Question", variant="primary")
+
+             debug_question_display = gr.Textbox(
+                 label="Question Details",
+                 lines=4,
+                 interactive=False
+             )
+
+             debug_ai_response = gr.Textbox(
+                 label="Full AI Response",
+                 lines=8,
+                 interactive=False
+             )
+
+             debug_result = gr.Textbox(
+                 label="Parsed Result",
+                 lines=3,
+                 interactive=False
+             )
+
+             def debug_single_question(question_id, model):
+                 if not app.model_manager.get_configured_providers():
+                     return "Please configure API providers first", "", ""
+
+                 try:
+                     question_id = int(question_id)
+                     matching_questions = app.df[app.df['question_id'] == question_id]
+
+                     if matching_questions.empty:
+                         return f"No question found with ID {question_id}", "", ""
+
+                     question_data = matching_questions.iloc[0].to_dict()
+
+                     question_info = f"""Question ID: {question_id}
+ Category: {question_data['category']}
+ Difficulty: {question_data['difficulty']}
+ Question: {question_data['question']}
+
+ Options:
+ A) {question_data['option_a']}
+ B) {question_data['option_b']}
+ C) {question_data['option_c']}
+ D) {question_data['option_d']}
+
+ Correct Answer: {question_data['correct_answer']}"""
+
+                     result = app.evaluate_single_question(question_id, model)
+
+                     if "error" in result:
+                         return question_info, "", f"Error: {result['error']}"
+
+                     ai_response = result.get('ai_response', 'No response')
+
+                     parsed_result = f"""Extracted Answer: {result.get('ai_answer', 'Unknown')}
+ Correct Letter: {result.get('correct_answer_letter', 'Unknown')}
+ Is Correct: {result.get('is_correct', False)}
+ AI Choice Text: {result.get('options', {}).get(result.get('ai_answer', ''), 'Unknown')}"""
+
+                     return question_info, ai_response, parsed_result
+
+                 except Exception as e:
+                     return f"Error processing question: {str(e)}", "", ""
+
+             debug_btn.click(
+                 fn=debug_single_question,
+                 inputs=[debug_question_id, debug_model],
+                 outputs=[debug_question_display, debug_ai_response, debug_result]
+             )
+
+         with gr.Tab("📈 Analytics"):
+             gr.HTML("""
+             <div style="padding: 20px;">
+                 <h3>Dataset Statistics</h3>
+             </div>
+             """)
+
+             # Dataset statistics
+             if not app.df.empty:
+                 stats_html = f"""
+                 <div style="display: grid; grid-template-columns: repeat(auto-fit, minmax(200px, 1fr)); gap: 20px; padding: 20px;">
+                     <div style="background: #f0f0f0; padding: 15px; border-radius: 8px;">
+                         <h4 style="color: #101010;">Total Questions</h4>
+                         <p style="font-size: 24px; color: #101010; font-weight: bold;">{len(app.df)}</p>
+                     </div>
+                     <div style="background: #f0f0f0; padding: 15px; border-radius: 8px;">
+                         <h4 style="color: #101010;">Categories</h4>
+                         <p style="font-size: 24px; color: #101010; font-weight: bold;">{len(app.df['category'].unique())}</p>
+                     </div>
+                     <div style="background: #f0f0f0; padding: 15px; border-radius: 8px;">
+                         <h4 style="color: #101010;">Difficulty Levels</h4>
+                         <p style="font-size: 24px; color: #101010; font-weight: bold;">{len(app.df['difficulty'].unique())}</p>
+                     </div>
+                 </div>
+
+                 <div style="padding: 20px;">
+                     <h4>Categories Distribution:</h4>
+                     <ul>
+                 """
+
+                 for category, count in app.df['category'].value_counts().items():
+                     stats_html += f"<li>{category}: {count} questions</li>"
+
+                 stats_html += """
+                     </ul>
+
+                     <h4>Difficulty Distribution:</h4>
+                     <ul>
+                 """
+
+                 for difficulty, count in app.df['difficulty'].value_counts().items():
+                     stats_html += f"<li>{difficulty}: {count} questions</li>"
+
+                 stats_html += "</ul></div>"
+
+                 gr.HTML(stats_html)
+
+     return interface
+
+ # Create and launch the interface
+ if __name__ == "__main__":
+     interface = create_gradio_interface()
+     interface.launch(
+         server_name="0.0.0.0",
+         server_port=7860,
+         show_error=True,
+         share=False
+     )
config.py ADDED
@@ -0,0 +1,192 @@
+ import openai
+ import anthropic
+ from typing import Dict, Tuple, Optional
+ from abc import ABC, abstractmethod
+
+ class ModelProvider(ABC):
+     """Abstract base class for model providers."""
+
+     @abstractmethod
+     def setup_client(self, api_key: str) -> Tuple[bool, str]:
+         """Set up the API client with the provided key."""
+         pass
+
+     @abstractmethod
+     def generate_response(self, prompt: str, model: str, max_tokens: int = 800) -> str:
+         """Generate a response using the specified model."""
+         pass
+
+     @abstractmethod
+     def get_available_models(self) -> list:
+         """Return the list of available models for this provider."""
+         pass
+
+ class OpenAIProvider(ModelProvider):
+     """OpenAI API provider implementation."""
+
+     def __init__(self):
+         self.client = None
+         self.models = [
+             "gpt-3.5-turbo",
+             "gpt-4",
+             "gpt-4-turbo",
+             "gpt-4o",
+             "gpt-4o-mini"
+         ]
+
+     def setup_client(self, api_key: str) -> Tuple[bool, str]:
+         """Configure the OpenAI client with the provided API key."""
+         if not api_key.strip():
+             return False, "OpenAI API key cannot be empty"
+
+         try:
+             self.client = openai.OpenAI(api_key=api_key.strip())
+             # Test the connection with a minimal request
+             response = self.client.chat.completions.create(
+                 model="gpt-3.5-turbo",
+                 messages=[{"role": "user", "content": "Hello"}],
+                 max_tokens=5
+             )
+             return True, "OpenAI client configured successfully"
+         except Exception as e:
+             return False, f"Failed to configure OpenAI client: {str(e)}"
+
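+     # The throwaway "Hello" completion above validates the key eagerly: it
+     # costs one tiny billed request, but surfaces bad keys at configuration
+     # time instead of mid-benchmark.
+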
+     def generate_response(self, prompt: str, model: str, max_tokens: int = 800) -> str:
+         """Generate a response using OpenAI models."""
+         if self.client is None:
+             raise Exception("OpenAI client not configured")
+
+         response = self.client.chat.completions.create(
+             model=model,
+             messages=[
+                 {"role": "system", "content": "You are a precise mathematician who always provides clear, step-by-step solutions and selects the correct answer from given options."},
+                 {"role": "user", "content": prompt}
+             ],
+             max_tokens=max_tokens,
+             temperature=0.0
+         )
+
+         return response.choices[0].message.content
+
+     def get_available_models(self) -> list:
+         """Return available OpenAI models."""
+         return self.models
+
+ class ClaudeProvider(ModelProvider):
+     """Anthropic Claude API provider implementation."""
+
+     def __init__(self):
+         self.client = None
+         self.models = [
+             "claude-3-haiku-20240307",
+             "claude-3-sonnet-20240229",
+             "claude-3-opus-20240229",
+             "claude-3-5-sonnet-20241022",
+             "claude-3-5-haiku-20241022",
+             "claude-sonnet-4-20250514",
+             "claude-opus-4-20250514",
+             "claude-opus-4-1-20250805"
+         ]
+
+     def setup_client(self, api_key: str) -> Tuple[bool, str]:
+         """Configure the Anthropic client with the provided API key."""
+         if not api_key.strip():
+             return False, "Anthropic API key cannot be empty"
+
+         try:
+             self.client = anthropic.Anthropic(api_key=api_key.strip())
+             # Test the connection with a minimal request
+             response = self.client.messages.create(
+                 model="claude-3-haiku-20240307",
+                 max_tokens=5,
+                 messages=[{"role": "user", "content": "Hello"}]
+             )
+             return True, "Claude client configured successfully"
+         except Exception as e:
+             return False, f"Failed to configure Claude client: {str(e)}"
+
+     def generate_response(self, prompt: str, model: str, max_tokens: int = 800) -> str:
+         """Generate a response using Claude models."""
+         if self.client is None:
+             raise Exception("Claude client not configured")
+
+         # Claude takes the system prompt as a separate parameter
+         system_prompt = "You are a precise mathematician who always provides clear, step-by-step solutions and selects the correct answer from given options."
+
+         response = self.client.messages.create(
+             model=model,
+             max_tokens=max_tokens,
+             system=system_prompt,
+             messages=[{"role": "user", "content": prompt}]
+         )
+
+         return response.content[0].text
+
+     def get_available_models(self) -> list:
+         """Return available Claude models."""
+         return self.models
+
+ class ModelManager:
+     """Manages multiple model providers and routing."""
+
+     def __init__(self):
+         self.providers = {
+             "openai": OpenAIProvider(),
+             "claude": ClaudeProvider()
+         }
+         self.configured_providers = set()
+
+     def setup_provider(self, provider_name: str, api_key: str) -> Tuple[bool, str]:
+         """Set up a specific provider with an API key."""
+         if provider_name not in self.providers:
+             return False, f"Unknown provider: {provider_name}"
+
+         success, message = self.providers[provider_name].setup_client(api_key)
+
+         if success:
+             self.configured_providers.add(provider_name)
+         else:
+             self.configured_providers.discard(provider_name)
+
+         return success, message
+
+     def get_provider_for_model(self, model: str) -> Optional[str]:
+         """Determine which provider handles the given model."""
+         for provider_name, provider in self.providers.items():
+             if model in provider.get_available_models():
+                 return provider_name
+         return None
+
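+     # Routing is by exact model-name membership in each provider's list, so
+     # any model string must appear in exactly one provider's models for the
+     # lookup above to be unambiguous.
+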
+     def generate_response(self, prompt: str, model: str, max_tokens: int = 800) -> str:
+         """Generate a response using the appropriate provider for the model."""
+         provider_name = self.get_provider_for_model(model)
+
+         if not provider_name:
+             raise Exception(f"No provider found for model: {model}")
+
+         if provider_name not in self.configured_providers:
+             raise Exception(f"Provider {provider_name} not configured")
+
+         return self.providers[provider_name].generate_response(prompt, model, max_tokens)
+
+     def get_all_models(self) -> Dict[str, list]:
+         """Get all available models grouped by provider."""
+         return {
+             provider_name: provider.get_available_models()
+             for provider_name, provider in self.providers.items()
+         }
+
+     def get_flat_model_list(self) -> list:
+         """Get a flat list of all available models."""
+         models = []
+         for provider in self.providers.values():
+             models.extend(provider.get_available_models())
+         return models
+
+     def is_provider_configured(self, provider_name: str) -> bool:
+         """Check if a provider is configured."""
+         return provider_name in self.configured_providers
+
+     def get_configured_providers(self) -> list:
+         """Get list of configured providers."""
+         return list(self.configured_providers)
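
A minimal usage sketch, not part of the diff and assuming only the classes defined above: ModelManager can be exercised headlessly to smoke-test a key and the model routing before launching the Gradio app. The OPENAI_API_KEY environment variable name is an illustrative assumption.

import os
from config import ModelManager

manager = ModelManager()
# setup_provider runs a one-token test call, so a bad key fails here rather than mid-benchmark
ok, msg = manager.setup_provider("openai", os.environ.get("OPENAI_API_KEY", ""))
print(msg)
if ok:
    print(manager.get_flat_model_list())
    # Routed to OpenAIProvider because "gpt-4o-mini" appears in its models list
    print(manager.generate_response("What is 2 + 2?", "gpt-4o-mini", max_tokens=50))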