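"""Gradio Space: describe graph images with Llama 3.2 Vision plus a LoRA adapter."""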
from unsloth import FastVisionModel
from peft import PeftModel
import gradio as gr

# Base Llama 3.2 Vision model and the fine-tuned LoRA adapter
model_name = "unsloth/Llama-3.2-11B-Vision-Instruct"
lora_repo = "alinasdkey/unsloth-pret-lora"

# Load base model and processor
model, processor = FastVisionModel.from_pretrained(
    model_name=model_name,
    device_map="auto",
    load_in_4bit=False,
    load_in_8bit=True,
)
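# Note: 8-bit loading relies on the bitsandbytes library; load_in_4bit=True
# would roughly halve memory again at some quality cost.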

# Apply the fine-tuned LoRA adapter on top of the base weights
model = PeftModel.from_pretrained(model, model_id=lora_repo)

# Put the model in inference mode (enables Unsloth's faster generation path)
FastVisionModel.for_inference(model)

# Inference function
def describe_image(image, instruction):
    if image is None:
        return "Please upload an image first."
    image = image.convert("RGB")

    prompt = instruction if instruction else "Describe this graph."

    # Llama 3.2 Vision expects the <|image|> placeholder inside the text, so
    # build the prompt with the chat template instead of tokenizing raw text.
    messages = [{
        "role": "user",
        "content": [
            {"type": "image"},
            {"type": "text", "text": prompt},
        ],
    }]
    input_text = processor.apply_chat_template(messages, add_generation_prompt=True)

    # Preprocess image and text together so pixel values, aspect-ratio IDs,
    # and attention masks stay aligned with the token sequence.
    inputs = processor(image, input_text, add_special_tokens=False, return_tensors="pt").to(model.device)

    # Greedy decoding; temperature/top_p only take effect with do_sample=True,
    # so they are omitted here.
    outputs = model.generate(**inputs, max_new_tokens=256, do_sample=False)

    # Decode only the newly generated tokens, skipping the echoed prompt
    generated = outputs[0][inputs["input_ids"].shape[-1]:]
    return processor.tokenizer.decode(generated, skip_special_tokens=True).strip()

# Gradio interface
gr.Interface(
    fn=describe_image,
    inputs=[
        gr.Image(type="pil", label="Upload a Graph Image"),
        gr.Textbox(label="Instruction (e.g. Summarize this graph)"),
    ],
    outputs="text",
    title="Welcome to the Graph Description AI: Pret",
    description="Upload a graph and get insightful analysis!",
).launch()