Anthony Liang committed
Commit a4ffa6f · Parent: 4d32b53

first commit

Files changed (3):
  1. README.md +6 -5
  2. app.py +508 -0
  3. requirements.txt +32 -0
README.md CHANGED
@@ -1,12 +1,13 @@
 ---
-title: Rewardeval Ui
-emoji: 💻
-colorFrom: indigo
-colorTo: indigo
+title: Rewardfm Eval Ui
+emoji: 🔥
+colorFrom: gray
+colorTo: red
 sdk: gradio
-sdk_version: 6.1.0
+sdk_version: 6.0.0
 app_file: app.py
 pinned: false
+short_description: UI for rfm evals
 ---
 
 Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,508 @@
#!/usr/bin/env python3
"""
Gradio app for RFM (Reward Foundation Model) inference visualization.
Supports single video (progress/success) and dual video (preference/similarity) predictions.
Uses an eval server for inference instead of loading models locally.
"""

import os
import sys
import tempfile
from typing import Optional, Tuple

import gradio as gr
import spaces  # Required for ZeroGPU
import matplotlib

matplotlib.use('Agg')  # Use non-interactive backend
import matplotlib.pyplot as plt
import numpy as np
import requests
from PIL import Image
import decord

from rfm.data.dataset_types import Trajectory, ProgressSample, PreferenceSample
from rfm.evals.eval_utils import build_payload, post_batch_npy

# Global server state
_server_state = {
    "server_url": None,
}

def check_server_health(server_url: str) -> Tuple[str, Optional[dict]]:
    """Check server health and get model info."""
    if not server_url:
        return "Please provide a server URL.", None

    try:
        url = server_url.rstrip("/") + "/health"
        response = requests.get(url, timeout=5.0)
        response.raise_for_status()
        health_data = response.json()

        # Also try to get GPU status for more info
        try:
            status_url = server_url.rstrip("/") + "/gpu_status"
            status_response = requests.get(status_url, timeout=5.0)
            if status_response.status_code == 200:
                status_data = status_response.json()
                health_data.update(status_data)
        except requests.exceptions.RequestException:
            # GPU status is optional; fall back to the health data alone
            pass

        _server_state["server_url"] = server_url
        return (
            f"Server connected: {health_data.get('available_gpus', 0)}/"
            f"{health_data.get('total_gpus', 0)} GPUs available",
            health_data,
        )
    except requests.exceptions.RequestException as e:
        return f"Error connecting to server: {str(e)}", None

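# For reference, the status string above assumes a JSON payload shaped roughly
# like the following; the keys mirror the .get() calls in check_server_health,
# and any extra fields returned by /gpu_status are merged in as-is (the exact
# server schema is an assumption, not documented here):
#
#   {"available_gpus": 2, "total_gpus": 4}
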
def extract_frames(video_path: str, max_frames: int = 16, fps: float = 1.0) -> np.ndarray:
    """Extract frames from a video file as a numpy array (T, H, W, C).

    Note: `fps` is accepted for symmetry with the UI sliders but is not used here;
    frames are sampled uniformly up to `max_frames`.
    """
    if video_path is None:
        return None

    # Gradio may pass video inputs as (path, subtitle) tuples
    if isinstance(video_path, tuple):
        video_path = video_path[0]

    if not os.path.exists(video_path):
        return None

    try:
        vr = decord.VideoReader(video_path, num_threads=1)
        total_frames = len(vr)

        if total_frames <= max_frames:
            frame_indices = list(range(total_frames))
        else:
            # Uniformly subsample max_frames indices across the video
            frame_indices = [int(i * total_frames / max_frames) for i in range(max_frames)]

        frames_array = vr.get_batch(frame_indices).asnumpy()  # Shape: (T, H, W, C)
        del vr
        return frames_array
    except Exception as e:
        print(f"Error extracting frames: {e}")
        return None

def process_single_video(
    video_path: str,
    task_text: str = "Complete the task",
    server_url: str = "",
    fps: float = 1.0,
) -> Tuple[Optional[str], Optional[str], Optional[str]]:
    """Process a single video for progress and success predictions using the eval server."""
    if not server_url:
        return None, None, "Please provide a server URL and check connection first."

    if not _server_state.get("server_url"):
        return None, None, "Server not connected. Please check server connection first."

    if video_path is None:
        return None, None, "Please provide a video."

    try:
        frames_array = extract_frames(video_path, max_frames=16, fps=fps)
        if frames_array is None or frames_array.size == 0:
            return None, None, "Could not extract frames from video."

        # Ensure frames are a (T, H, W, C) uint8 array
        if frames_array.dtype != np.uint8:
            frames_array = np.clip(frames_array, 0, 255).astype(np.uint8)

        num_frames = frames_array.shape[0]
        frames_shape = frames_array.shape  # (T, H, W, C)

        # Create target progress (placeholder - would be None in real use)
        target_progress = np.linspace(0.0, 1.0, num=num_frames).tolist()
        success_label = [1.0 if prog > 0.5 else 0.0 for prog in target_progress]

        # Create Trajectory
        trajectory = Trajectory(
            task=task_text,
            frames=frames_array,
            frames_shape=frames_shape,
            target_progress=target_progress,
            success_label=success_label,
            metadata={"source": "gradio_app"},
        )

        # Create ProgressSample
        progress_sample = ProgressSample(
            trajectory=trajectory,
            data_gen_strategy="demo",
        )

        # Build payload and send to server
        files, sample_data = build_payload([progress_sample])
        response = post_batch_npy(server_url, files, sample_data, timeout_s=120.0)

        # Process response
        outputs_progress = response.get("outputs_progress", {})
        progress_pred = outputs_progress.get("progress_pred", [])

        # Extract progress predictions for the first (only) sample
        if progress_pred and len(progress_pred) > 0:
            progress_array = np.array(progress_pred[0])
        else:
            progress_array = np.array([])

        # Create plots
        progress_plot = create_progress_plot(progress_array, num_frames)
        success_plot = None  # Success predictions are not always available from the server

        info_text = f"**Frames processed:** {num_frames}\n"
        if len(progress_array) > 0:
            info_text += f"**Final progress:** {progress_array[-1]:.3f}\n"

        return progress_plot, success_plot, info_text

    except Exception as e:
        return None, None, f"Error processing video: {str(e)}"

def process_dual_videos(
    video_a_path: str,
    video_b_path: str,
    task_text: str = "Complete the task",
    prediction_type: str = "preference",
    server_url: str = "",
    fps: float = 1.0,
) -> Tuple[Optional[str], Optional[str]]:
    """Process two videos for preference or similarity prediction using the eval server."""
    if not server_url:
        return "Please provide a server URL and check connection first.", None

    if not _server_state.get("server_url"):
        return "Server not connected. Please check server connection first.", None

    if video_a_path is None or video_b_path is None:
        return "Please provide both videos.", None

    try:
        frames_array_a = extract_frames(video_a_path, max_frames=16, fps=fps)
        frames_array_b = extract_frames(video_b_path, max_frames=16, fps=fps)

        if frames_array_a is None or frames_array_a.size == 0:
            return "Could not extract frames from video A.", None
        if frames_array_b is None or frames_array_b.size == 0:
            return "Could not extract frames from video B.", None

        # Convert frames to uint8
        if frames_array_a.dtype != np.uint8:
            frames_array_a = np.clip(frames_array_a, 0, 255).astype(np.uint8)
        if frames_array_b.dtype != np.uint8:
            frames_array_b = np.clip(frames_array_b, 0, 255).astype(np.uint8)

        num_frames_a = frames_array_a.shape[0]
        num_frames_b = frames_array_b.shape[0]
        frames_shape_a = frames_array_a.shape
        frames_shape_b = frames_array_b.shape

        # Create target progress for both trajectories
        target_progress_a = np.linspace(0.0, 1.0, num=num_frames_a).tolist()
        target_progress_b = np.linspace(0.0, 1.0, num=num_frames_b).tolist()
        success_label_a = [1.0 if prog > 0.5 else 0.0 for prog in target_progress_a]
        success_label_b = [1.0 if prog > 0.5 else 0.0 for prog in target_progress_b]

        # Create trajectories
        trajectory_a = Trajectory(
            task=task_text,
            frames=frames_array_a,
            frames_shape=frames_shape_a,
            target_progress=target_progress_a,
            success_label=success_label_a,
            metadata={"source": "gradio_app", "trajectory": "A"},
        )

        trajectory_b = Trajectory(
            task=task_text,
            frames=frames_array_b,
            frames_shape=frames_shape_b,
            target_progress=target_progress_b,
            success_label=success_label_b,
            metadata={"source": "gradio_app", "trajectory": "B"},
        )

        if prediction_type == "preference":
            # Create PreferenceSample (A = chosen, B = rejected)
            preference_sample = PreferenceSample(
                chosen_trajectory=trajectory_a,
                rejected_trajectory=trajectory_b,
                data_gen_strategy="demo",
            )

            # Build payload and send to server
            files, sample_data = build_payload([preference_sample])
            response = post_batch_npy(server_url, files, sample_data, timeout_s=120.0)

            # Process response
            outputs_preference = response.get("outputs_preference", {})
            predictions = outputs_preference.get("predictions", [])
            prediction_probs = outputs_preference.get("prediction_probs", [])

            result_text = "**Preference Prediction:**\n"
            if prediction_probs and len(prediction_probs) > 0:
                prob = prediction_probs[0]
                result_text += f"- Probability (A preferred): {prob:.3f}\n"
                result_text += f"- Interpretation: {'Video A is preferred' if prob > 0.5 else 'Video B is preferred'}\n"
            else:
                result_text += "Could not extract preference prediction from server response.\n"

        else:  # similarity - not yet implemented in the eval server response format
            result_text = "Similarity prediction not yet supported in eval server response format."

        # Create comparison plot
        frames_a_list = [Image.fromarray(frame) for frame in frames_array_a]
        frames_b_list = [Image.fromarray(frame) for frame in frames_array_b]
        comparison_plot = create_comparison_plot(frames_a_list, frames_b_list, prediction_type)

        return result_text, comparison_plot

    except Exception as e:
        return f"Error processing videos: {str(e)}", None

def create_progress_plot(progress_pred: np.ndarray, num_frames: int) -> str:
    """Create progress prediction plot and return the path to the saved image."""
    plt.rcParams['font.family'] = 'DejaVu Sans'
    plt.rcParams['font.size'] = 16

    fig, ax = plt.subplots(figsize=(10, 6))

    if len(progress_pred) > 0:
        frame_indices = np.arange(len(progress_pred))
        ax.plot(frame_indices, progress_pred, 'b-', linewidth=3, marker='o', markersize=8, label='Progress Prediction')
    else:
        ax.text(0.5, 0.5, 'No progress prediction available',
                horizontalalignment='center', verticalalignment='center',
                transform=ax.transAxes, fontsize=18)

    ax.set_xlabel('Frame Index', fontsize=18, fontweight='bold')
    ax.set_ylabel('Progress (0-1)', fontsize=18, fontweight='bold')
    ax.set_title('Progress Prediction', fontsize=20, fontweight='bold')
    ax.set_ylim([0, 1])
    ax.legend(fontsize=14)

    plt.tight_layout()

    tmp_file = tempfile.NamedTemporaryFile(delete=False, suffix='.png')
    plt.savefig(tmp_file.name, dpi=150, bbox_inches='tight')
    plt.close()

    return tmp_file.name


def create_success_plot(success_probs: np.ndarray, num_frames: int) -> str:
    """Create success probability plot and return the path to the saved image."""
    plt.rcParams['font.family'] = 'DejaVu Sans'
    plt.rcParams['font.size'] = 16

    fig, ax = plt.subplots(figsize=(10, 6))

    if len(success_probs) > 0:
        frame_indices = np.arange(len(success_probs))
        ax.plot(frame_indices, success_probs, 'g-', linewidth=3, marker='s', markersize=8, label='Success Probability')
        ax.axhline(y=0.5, color='r', linestyle='--', linewidth=2, label='Decision Threshold (0.5)')
    else:
        ax.text(0.5, 0.5, 'No success prediction available',
                horizontalalignment='center', verticalalignment='center',
                transform=ax.transAxes, fontsize=18)

    ax.set_xlabel('Frame Index', fontsize=18, fontweight='bold')
    ax.set_ylabel('Success Probability (0-1)', fontsize=18, fontweight='bold')
    ax.set_title('Success Prediction', fontsize=20, fontweight='bold')
    ax.set_ylim([0, 1])
    ax.legend(fontsize=14)

    plt.tight_layout()

    tmp_file = tempfile.NamedTemporaryFile(delete=False, suffix='.png')
    plt.savefig(tmp_file.name, dpi=150, bbox_inches='tight')
    plt.close()

    return tmp_file.name

def create_comparison_plot(frames_a: list, frames_b: list, prediction_type: str) -> str:
    """Create a side-by-side comparison plot of two videos."""
    plt.rcParams['font.family'] = 'DejaVu Sans'
    plt.rcParams['font.size'] = 16

    # Display at most 8 frames per video
    num_display = min(8, max(len(frames_a), len(frames_b)))
    fig, axes = plt.subplots(2, num_display, figsize=(16, 4))

    # With a single column, subplots returns a 1-D array; normalize to (2, N)
    if len(axes.shape) == 1:
        axes = axes.reshape(2, -1)

    # Sample frames to display
    indices_a = np.linspace(0, len(frames_a) - 1, num_display, dtype=int) if len(frames_a) > 1 else [0]
    indices_b = np.linspace(0, len(frames_b) - 1, num_display, dtype=int) if len(frames_b) > 1 else [0]

    # Display frames from video A (top row)
    for idx, frame_idx in enumerate(indices_a):
        if frame_idx < len(frames_a):
            axes[0, idx].imshow(frames_a[frame_idx])
            axes[0, idx].axis('off')
            axes[0, idx].set_title(f'Frame {frame_idx}', fontsize=12)

    # Display frames from video B (bottom row)
    for idx, frame_idx in enumerate(indices_b):
        if frame_idx < len(frames_b):
            axes[1, idx].imshow(frames_b[frame_idx])
            axes[1, idx].axis('off')
            axes[1, idx].set_title(f'Frame {frame_idx}', fontsize=12)

    # Add row labels
    fig.text(0.02, 0.75, 'Video A', rotation=90, fontsize=18, fontweight='bold', va='center')
    fig.text(0.02, 0.25, 'Video B', rotation=90, fontsize=18, fontweight='bold', va='center')

    title = f"{prediction_type.capitalize()} Comparison: Video A vs Video B"
    fig.suptitle(title, fontsize=20, fontweight='bold', y=0.98)

    plt.tight_layout()

    # Save to temporary file
    tmp_file = tempfile.NamedTemporaryFile(delete=False, suffix='.png')
    plt.savefig(tmp_file.name, dpi=150, bbox_inches='tight')
    plt.close()

    return tmp_file.name

# Create Gradio interface
try:
    # Try with theme (Gradio 4.0+)
    demo = gr.Blocks(title="RFM Inference Visualizer", theme=gr.themes.Soft())
except TypeError:
    # Fallback for older Gradio versions without theme support
    demo = gr.Blocks(title="RFM Inference Visualizer")

with demo:
    gr.Markdown(
        """
        # RFM (Reward Foundation Model) Inference Visualizer

        Visualize progress, success, preference, and similarity predictions from the Reward Foundation Model.

        **Features:**
        - **Single Video**: Get progress and success predictions
        - **Dual Videos**: Compare two videos with preference or similarity predictions

        **Note:** This app connects to an eval server. Please provide the server URL and check the connection before use.
        """
    )

    with gr.Tab("Server Setup"):
        gr.Markdown("### Connect to Eval Server")
        gr.Markdown("Enter the eval server URL and check the connection.")

        with gr.Row():
            with gr.Column(scale=3):
                server_url_input = gr.Textbox(
                    label="Server URL",
                    placeholder="http://40.119.56.66:8000",
                    value="http://40.119.56.66:8000",
                    interactive=True,
                )
            with gr.Column(scale=1):
                check_connection_btn = gr.Button("Check Connection", variant="primary", size="sm")

        server_status = gr.Markdown("Enter server URL and click 'Check Connection'")

        def on_check_connection(server_url: str):
            """Handle the server connection check; only the status string is displayed."""
            status, health_data = check_server_health(server_url)
            return status

        check_connection_btn.click(
            fn=on_check_connection,
            inputs=[server_url_input],
            outputs=[server_status],
        )

    with gr.Tab("Progress Prediction"):
        gr.Markdown("### Progress & Success Prediction")
        with gr.Row():
            with gr.Column():
                single_video_input = gr.Video(label="Upload Video", height=300)
                task_text_input = gr.Textbox(
                    label="Task Description",
                    placeholder="Describe the task (e.g., 'Pick up the red block')",
                    value="Complete the task",
                )
                fps_input_single = gr.Slider(
                    label="FPS (Frames Per Second)",
                    minimum=0.1,
                    maximum=10.0,
                    value=1.0,
                    step=0.1,
                    info="Frames per second to extract from video (higher = more frames)",
                )
                analyze_single_btn = gr.Button("Analyze Video", variant="primary")

            with gr.Column():
                progress_plot = gr.Image(label="Progress Prediction", height=400)
                success_plot = gr.Image(label="Success Prediction", height=400)
                info_output = gr.Markdown("")

        analyze_single_btn.click(
            fn=process_single_video,
            inputs=[single_video_input, task_text_input, server_url_input, fps_input_single],
            outputs=[progress_plot, success_plot, info_output],
        )

    with gr.Tab("Preference/Similarity Analysis"):
        gr.Markdown("### Preference & Similarity Prediction")
        with gr.Row():
            with gr.Column():
                video_a_input = gr.Video(label="Video A", height=250)
                video_b_input = gr.Video(label="Video B", height=250)
                task_text_dual = gr.Textbox(
                    label="Task Description",
                    placeholder="Describe the task",
                    value="Complete the task",
                )
                prediction_type = gr.Radio(
                    choices=["preference", "similarity"],
                    value="preference",
                    label="Prediction Type",
                )
                fps_input_dual = gr.Slider(
                    label="FPS (Frames Per Second)",
                    minimum=0.1,
                    maximum=10.0,
                    value=1.0,
                    step=0.1,
                    info="Frames per second to extract from videos (higher = more frames)",
                )
                analyze_dual_btn = gr.Button("Compare Videos", variant="primary")

            with gr.Column():
                result_text = gr.Markdown("")
                comparison_plot = gr.Image(label="Video Comparison", height=500)

        analyze_dual_btn.click(
            fn=process_dual_videos,
            inputs=[video_a_input, video_b_input, task_text_dual, prediction_type, server_url_input, fps_input_dual],
            outputs=[result_text, comparison_plot],
        )

def main():
    """Launch the Gradio app."""
    # Check if reload mode is requested (note: the flag is read for dev
    # workflows but is not currently forwarded to demo.launch())
    watch_files = os.getenv("GRADIO_WATCH", "0") == "1" or "--reload" in sys.argv

    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=False,
        show_error=True,  # Show full error messages
    )


if __name__ == "__main__":
    main()
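For reference, the same eval-server round trip can be exercised outside Gradio. The sketch below mirrors `process_single_video` above, reusing `build_payload` and `post_batch_npy` exactly as app.py calls them; the server URL is a placeholder, the frames are dummy data, and the `outputs_progress.progress_pred` response keys are assumed to match the server's actual schema (they are taken from the app's own parsing code).

```python
# Minimal client sketch, mirroring process_single_video in app.py.
import numpy as np

from rfm.data.dataset_types import Trajectory, ProgressSample
from rfm.evals.eval_utils import build_payload, post_batch_npy

server_url = "http://localhost:8000"  # placeholder; point at a real eval server

# Dummy 8-frame clip; a real client would load frames with decord as above
frames = np.zeros((8, 64, 64, 3), dtype=np.uint8)
progress = np.linspace(0.0, 1.0, num=8).tolist()

sample = ProgressSample(
    trajectory=Trajectory(
        task="Complete the task",
        frames=frames,
        frames_shape=frames.shape,
        target_progress=progress,
        success_label=[1.0 if p > 0.5 else 0.0 for p in progress],
        metadata={"source": "client_sketch"},
    ),
    data_gen_strategy="demo",
)

files, sample_data = build_payload([sample])
response = post_batch_npy(server_url, files, sample_data, timeout_s=120.0)

# Assumed response keys, per the parsing in process_single_video
print(response.get("outputs_progress", {}).get("progress_pred", []))
```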
requirements.txt ADDED
@@ -0,0 +1,32 @@
# Requirements for RFM Eval UI Gradio App

# Core dependencies
matplotlib>=3.5.0
numpy>=1.21.0
torch>=2.0.0
PyYAML>=6.0
Pillow>=9.0.0

# HuggingFace
huggingface-hub>=0.16.0
transformers>=4.30.0

# Sentence transformers (for ReWiND models)
sentence-transformers>=2.2.0

# Qwen VL utilities
qwen-vl-utils

# Video processing
opencv-python-headless>=4.5.0
decord>=0.6.0  # For video frame extraction (same as preprocess_datasets.py)

# Development tools (optional, for auto-reload)
watchfiles  # For file watching during development

# RFM package (installed from git repository)
# For local development, you can also install with: pip install -e ../ (from the parent directory)
git+https://github.com/aliang8/reward_fm.git@93b1ad4b5a530fb32c234bf926b659105e676d00

# Make sure a newer version of gradio is installed
gradio==4.44.0
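After `pip install -r requirements.txt`, a quick sanity check like the sketch below confirms the pinned stack imports cleanly. The module list follows the requirements above (`PIL` is Pillow's import name, `cv2` is opencv's); `requests`, which app.py imports directly, is assumed to arrive transitively via huggingface-hub.

```python
# Quick post-install sanity check for the pins above.
import importlib

for name in ["gradio", "numpy", "matplotlib", "torch", "PIL", "cv2", "decord", "requests"]:
    try:
        mod = importlib.import_module(name)
        print(f"{name}: {getattr(mod, '__version__', 'ok')}")
    except ImportError as e:
        print(f"{name}: MISSING ({e})")
```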