import os
import gradio as gr
import torch
from PIL import Image
from transformers import MllamaForConditionalGeneration, AutoProcessor
from peft import PeftModel
from huggingface_hub import login
import json
import matplotlib.pyplot as plt
import io
import base64


def check_environment():
    required_vars = ["HF_TOKEN"]
    missing_vars = [var for var in required_vars if var not in os.environ]
    if missing_vars:
        raise ValueError(
            f"Missing required environment variables: {', '.join(missing_vars)}\n"
            "Please set the HF_TOKEN environment variable with your Hugging Face token"
        )


# # Login to Hugging Face
# check_environment()
# login(token=os.environ["HF_TOKEN"], add_to_git_credential=True)

# Load model and processor (do this outside the inference function to avoid reloading)
# base_model_path = (
#     "taesiri/BugsBunny-LLama-3.2-11B-Vision-BaseCaptioner-Medium-FullModel"
# )
# processor = AutoProcessor.from_pretrained(base_model_path)
# model = MllamaForConditionalGeneration.from_pretrained(
#     base_model_path,
#     torch_dtype=torch.bfloat16,
#     device_map="cuda",
#     cache_dir="./",
# )
# model = PeftModel.from_pretrained(model, lora_weights_path)

local_model_path = "../merged-llama-3.2-dummy"

# Load model and processor (do this outside the inference function to avoid reloading)
base_model_path = local_model_path
# lora_weights_path = "taesiri/BugsBunny-LLama-3.2-11B-Vision-Base-Medium-LoRA"

processor = AutoProcessor.from_pretrained(base_model_path)
model = MllamaForConditionalGeneration.from_pretrained(
    base_model_path,
    torch_dtype=torch.bfloat16,
    device_map="cuda",
    cache_dir="./",
)
model.tie_weights()
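
# Note (assumption, not part of the original Space): the load above targets a CUDA
# device. A minimal CPU fallback sketch could look like:
#   device = "cuda" if torch.cuda.is_available() else "cpu"
#   dtype = torch.bfloat16 if device == "cuda" else torch.float32
#   model = MllamaForConditionalGeneration.from_pretrained(
#       base_model_path, torch_dtype=dtype, device_map=device, cache_dir="./"
#   )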


def create_color_palette_image(colors):
    if not colors or not isinstance(colors, list):
        return None
    try:
        # Validate color format
        for color in colors:
            if not isinstance(color, str) or not color.startswith("#"):
                return None

        # Create figure and axis
        fig, ax = plt.subplots(figsize=(10, 2))

        # Create rectangles for each color
        for i, color in enumerate(colors):
            ax.add_patch(plt.Rectangle((i, 0), 1, 1, facecolor=color))

        # Set the view limits and aspect ratio
        ax.set_xlim(0, len(colors))
        ax.set_ylim(0, 1)
        ax.set_xticks([])
        ax.set_yticks([])
        return fig  # Return the matplotlib figure directly
    except Exception as e:
        print(f"Error creating color palette: {e}")
        return None
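
# Hypothetical usage (not wired into the UI below): the helper expects a list of
# hex color strings and returns a matplotlib Figure, e.g.
#   fig = create_color_palette_image(["#ff0000", "#00ff00", "#0000ff"])
#   if fig is not None:
#       fig.savefig("palette.png")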


def inference(image):
    # The click handler below expects eight outputs, so every return path yields an
    # 8-tuple: four JSON panels, raw output, error text, status, and the full JSON.
    if image is None:
        return ("Please provide an image",) * 4 + ("", "No image provided", "Error", {})
    if not isinstance(image, Image.Image):
        try:
            image = Image.fromarray(image)
        except Exception as e:
            print(f"Image conversion error: {e}")
            return ("Invalid image format",) * 4 + ("", str(e), "Error", {})

    # Prepare input
    messages = [
        {
            "role": "user",
            "content": [
                {"type": "image"},
                {
                    "type": "text",
                    "text": "Analyze this image for fire, smoke, haze, or other related conditions.",
                },
            ],
        }
    ]
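
    # apply_chat_template renders the messages into the model's chat prompt format;
    # add_generation_prompt=True appends the assistant header so the model starts
    # its reply immediately after it.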
    input_text = processor.apply_chat_template(messages, add_generation_prompt=True)

    try:
        # Move inputs to the correct device
        inputs = processor(
            image, input_text, add_special_tokens=False, return_tensors="pt"
        ).to(model.device)

        with torch.no_grad():
            output = model.generate(**inputs, max_new_tokens=2048)

        # Clear CUDA cache after inference
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
    except Exception as e:
        print(f"Inference error: {e}")
        return ("Error during inference",) * 4 + ("", str(e), "Error", {})

    # Decode output
    result = processor.decode(output[0], skip_special_tokens=True)
    print("DEBUG: Full decoded output:", result)

    try:
        json_str = result.strip().split("assistant\n")[1].strip()
        parsed_json = json.loads(json_str)
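
        # The keys read below (predictions, description, confidence_score,
        # environmental_factors, detections, uncertainty_factors,
        # false_positive_indicators) reflect the JSON schema the model is expected
        # to emit; .get() supplies defaults when a key is missing.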

        # Create specific JSON subsets for each section
        fire_analysis = {
            "predictions": parsed_json.get("predictions", "N/A"),
            "description": parsed_json.get("description", "No description available"),
            "confidence_scores": parsed_json.get("confidence_score", {}),
        }
        environment_analysis = {
            "environmental_factors": parsed_json.get("environmental_factors", {})
        }
        detection_analysis = {
            "detections": parsed_json.get("detections", []),
            "detection_count": len(parsed_json.get("detections", [])),
        }
        report_analysis = {
            "uncertainty_factors": parsed_json.get("uncertainty_factors", []),
            "false_positive_indicators": parsed_json.get("false_positive_indicators", []),
        }

        return (
            json.dumps(fire_analysis, indent=2),
            json.dumps(environment_analysis, indent=2),
            json.dumps(detection_analysis, indent=2),
            json.dumps(report_analysis, indent=2),
            json_str,
            "",
            "Analysis complete",
            parsed_json,
        )
    except Exception as e:
        print("DEBUG: Error processing response:", e)
        return (
            "Error processing response",
            "",
            "",
            "",
            str(result),
            str(e),
            "Error",
            {},
        )


# Build the Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("# Fire Detection Demo")

    with gr.Row():
        with gr.Column(scale=1):
            image_input = gr.Image(
                type="pil",
                label="Upload Image",
                elem_id="large-image",
            )
            submit_btn = gr.Button("Analyze Image", variant="primary")

            # Example images
            gr.Examples(
                examples=[
                    "examples/Birch MWF014-0001.png",
                    "examples/Birch MWF014-0006.png",
                    "examples/Blackstone PB-0010.png",
                ],
                inputs=image_input,
                label="Example Images",
                examples_per_page=4,
            )

    with gr.Tabs() as tabs:
        with gr.Tab("Analysis Results"):
            with gr.Row():
                with gr.Column():
                    fire_output = gr.JSON(
                        label="Fire Details",
                    )
                with gr.Column():
                    environment_output = gr.JSON(
                        label="Environment Details",
                    )
            with gr.Row():
                with gr.Column():
                    detection_output = gr.JSON(
                        label="Detection Details",
                    )
                with gr.Column():
                    report_output = gr.JSON(
                        label="Report Details",
                    )
        with gr.Tab("JSON Output", id=0):
            json_output = gr.JSON(
                label="Detailed JSON Results",
            )
        with gr.Tab("Raw Output"):
            raw_output = gr.Textbox(
                label="Raw JSON Response",
                lines=10,
            )

    error_box = gr.Textbox(label="Error Messages", visible=False)
    status_text = gr.Textbox(label="Status", value="Ready", interactive=False)
    submit_btn.click(
        fn=inference,
        inputs=[image_input],
        outputs=[
            fire_output,
            environment_output,
            detection_output,
            report_output,
            raw_output,
            error_box,
            status_text,
            json_output,
        ],
    )

demo.launch(share=True)
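
# share=True also requests a temporary public gradio.live link in addition to the
# local server; drop it if only local access is needed.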