import os
import json

import gradio as gr
import torch
import matplotlib.pyplot as plt
from PIL import Image
from transformers import MllamaForConditionalGeneration, AutoProcessor
from peft import PeftModel
from huggingface_hub import login


def check_environment():
    """Fail fast if the Hugging Face token is not configured."""
    required_vars = ["HF_TOKEN"]
    missing_vars = [var for var in required_vars if var not in os.environ]
    if missing_vars:
        raise ValueError(
            f"Missing required environment variables: {', '.join(missing_vars)}\n"
            "Please set the HF_TOKEN environment variable with your Hugging Face token"
        )


# # Login to Hugging Face
# check_environment()
# login(token=os.environ["HF_TOKEN"], add_to_git_credential=True)

# Load model and processor (do this outside the inference function to avoid reloading)
# base_model_path = (
#     "taesiri/BugsBunny-LLama-3.2-11B-Vision-BaseCaptioner-Medium-FullModel"
# )
# processor = AutoProcessor.from_pretrained(base_model_path)
# model = MllamaForConditionalGeneration.from_pretrained(
#     base_model_path,
#     torch_dtype=torch.bfloat16,
#     device_map="cuda",
#     cache_dir="./",
# )
# model = PeftModel.from_pretrained(model, lora_weights_path)

# Path to the locally merged checkpoint used by this demo
local_model_path = "../merged-llama-3.2-dummy"

# Load model and processor (do this outside the inference function to avoid reloading)
base_model_path = local_model_path
# lora_weights_path = "taesiri/BugsBunny-LLama-3.2-11B-Vision-Base-Medium-LoRA"

processor = AutoProcessor.from_pretrained(base_model_path)
model = MllamaForConditionalGeneration.from_pretrained(
    base_model_path, torch_dtype=torch.bfloat16, device_map="cuda", cache_dir="./"
)
model.tie_weights()


def create_color_palette_image(colors):
    """Render a list of hex colors (e.g. ["#ff0000", ...]) as a matplotlib swatch figure."""
    if not colors or not isinstance(colors, list):
        return None
    try:
        # Validate color format
        for color in colors:
            if not isinstance(color, str) or not color.startswith("#"):
                return None

        # Create figure and axis
        fig, ax = plt.subplots(figsize=(10, 2))

        # Create rectangles for each color
        for i, color in enumerate(colors):
            ax.add_patch(plt.Rectangle((i, 0), 1, 1, facecolor=color))

        # Set the view limits and hide the axes
        ax.set_xlim(0, len(colors))
        ax.set_ylim(0, 1)
        ax.set_xticks([])
        ax.set_yticks([])

        return fig  # Return the matplotlib figure directly
    except Exception as e:
        print(f"Error creating color palette: {e}")
        return None


def inference(image):
    # Every return supplies one value per Gradio output component (8 in total):
    # fire, environment, detection, report, raw text, error message, status, full JSON.
    if image is None:
        return {"error": "Please provide an image"}, {}, {}, {}, "", "No image provided", "Error", {}

    if not isinstance(image, Image.Image):
        try:
            image = Image.fromarray(image)
        except Exception as e:
            print(f"Image conversion error: {e}")
            return {"error": "Invalid image format"}, {}, {}, {}, "", str(e), "Error", {}

    # Prepare the chat-formatted prompt with an image placeholder
    messages = [
        {
            "role": "user",
            "content": [
                {"type": "image"},
                {
                    "type": "text",
                    "text": "Analyze this image for fire, smoke, haze, or other related conditions.",
                },
            ],
        }
    ]
    input_text = processor.apply_chat_template(messages, add_generation_prompt=True)

    try:
        # Move inputs to the same device as the model
        inputs = processor(
            image, input_text, add_special_tokens=False, return_tensors="pt"
        ).to(model.device)

        with torch.no_grad():
            output = model.generate(**inputs, max_new_tokens=2048)

        # Clear CUDA cache after inference
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
    except Exception as e:
        print(f"Inference error: {e}")
        return {"error": "Error during inference"}, {}, {}, {}, "", str(e), "Error", {}

    # Decode output
    result = processor.decode(output[0], skip_special_tokens=True)
    print("DEBUG: Full decoded output:", result)

    try:
        # The assistant's JSON answer follows the "assistant" turn marker
        json_str = result.strip().split("assistant\n")[1].strip()
        parsed_json = json.loads(json_str)
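        # Expected shape of the model's JSON answer, inferred from the keys read below.
        # This is an assumption about the fine-tuned checkpoint's output format, not a
        # guaranteed schema:
        # {
        #   "predictions": ...,
        #   "description": "...",
        #   "confidence_score": {...},
        #   "environmental_factors": {...},
        #   "detections": [...],
        #   "uncertainty_factors": [...],
        #   "false_positive_indicators": [...]
        # }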
        # Create specific JSON subsets for each results panel
        fire_analysis = {
            "predictions": parsed_json.get("predictions", "N/A"),
            "description": parsed_json.get("description", "No description available"),
            "confidence_scores": parsed_json.get("confidence_score", {}),
        }
        environment_analysis = {
            "environmental_factors": parsed_json.get("environmental_factors", {}),
        }
        detection_analysis = {
            "detections": parsed_json.get("detections", []),
            "detection_count": len(parsed_json.get("detections", [])),
        }
        report_analysis = {
            "uncertainty_factors": parsed_json.get("uncertainty_factors", []),
            "false_positive_indicators": parsed_json.get("false_positive_indicators", []),
        }

        # gr.JSON components render dicts directly, so return the dicts themselves
        return (
            fire_analysis,
            environment_analysis,
            detection_analysis,
            report_analysis,
            json_str,
            "",
            "Analysis complete",
            parsed_json,
        )
    except Exception as e:
        print("DEBUG: Error processing response:", e)
        return (
            {"error": "Error processing response"},
            {},
            {},
            {},
            str(result),
            str(e),
            "Error",
            {},
        )


# Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("# Fire Detection Demo")

    with gr.Row():
        with gr.Column(scale=1):
            image_input = gr.Image(
                type="pil",
                label="Upload Image",
                elem_id="large-image",
            )
            submit_btn = gr.Button("Analyze Image", variant="primary")

            # Example images
            gr.Examples(
                examples=[
                    "examples/Birch MWF014-0001.png",
                    "examples/Birch MWF014-0006.png",
                    "examples/Blackstone PB-0010.png",
                ],
                inputs=image_input,
                label="Example Images",
                examples_per_page=4,
            )

        with gr.Tabs() as tabs:
            with gr.Tab("Analysis Results"):
                with gr.Row():
                    with gr.Column():
                        fire_output = gr.JSON(label="Fire Details")
                    with gr.Column():
                        environment_output = gr.JSON(label="Environment Details")
                with gr.Row():
                    with gr.Column():
                        detection_output = gr.JSON(label="Detection Details")
                    with gr.Column():
                        report_output = gr.JSON(label="Report Details")

            with gr.Tab("JSON Output", id=0):
                json_output = gr.JSON(label="Detailed JSON Results")

            with gr.Tab("Raw Output"):
                raw_output = gr.Textbox(
                    label="Raw JSON Response",
                    lines=10,
                )

    error_box = gr.Textbox(label="Error Messages", visible=False)
    status_text = gr.Textbox(label="Status", value="Ready", interactive=False)

    submit_btn.click(
        fn=inference,
        inputs=[image_input],
        outputs=[
            fire_output,
            environment_output,
            detection_output,
            report_output,
            raw_output,
            error_box,
            status_text,
            json_output,
        ],
    )

demo.launch(share=True)