"""Gradio playground that shows the output of the Surya-1.0 model as a heatmap.

At import time the checkpoint is downloaded (if it is not already on disk) and
loaded with ``torch.load``; running the file as a script starts the Gradio UI.
"""

import os

import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import requests
import torch
from matplotlib.figure import Figure
from PIL import Image

# Remote location of the model checkpoint.
MODEL_URL = "https://huggingface.co/nasa-ibm-ai4science/Surya-1.0/resolve/main/surya.366m.v1.pt"
# Local file name the checkpoint is cached under.
MODEL_FILE = "surya.366m.v1.pt"


def download_model(url=MODEL_URL, path=MODEL_FILE, *, timeout=60, chunk_size=1 << 20):
    """Download the checkpoint at *url* to *path* unless *path* already exists.

    The response is streamed in *chunk_size*-byte pieces to ``<path>.part``
    and renamed to *path* only after the whole body has been written, so a
    failed or interrupted download never leaves a truncated file that a later
    run would mistake for a valid checkpoint.

    Args:
        url: Where to fetch the checkpoint from.
        path: Local destination file.
        timeout: Seconds passed to ``requests`` for connecting / each read.
        chunk_size: Bytes per streamed chunk.

    Raises:
        requests.HTTPError: if the server answers with an error status.
        requests.RequestException: on connection problems or timeouts.
    """
    if os.path.exists(path):
        return
    print("Baixando pesos do Surya-1.0...")
    tmp_path = path + ".part"
    try:
        with requests.get(url, stream=True, timeout=timeout) as r:
            # Without this an HTML error page would be saved as the checkpoint.
            r.raise_for_status()
            with open(tmp_path, "wb") as f:
                for chunk in r.iter_content(chunk_size=chunk_size):
                    f.write(chunk)
        # Atomic rename: *path* only ever appears once it is complete.
        os.replace(tmp_path, path)
    finally:
        # Only still present if something above failed.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
    print("Download concluído!")


def _load_model(path=MODEL_FILE):
    """Load the checkpoint at *path* onto the CPU and switch it to eval mode.

    Raises:
        TypeError: if the file does not contain a pickled ``torch.nn.Module``
            (for example, if it only holds a ``state_dict``).
    """
    # SECURITY NOTE(review): torch.load unpickles the file, which can execute
    # arbitrary code; only load checkpoints from a trusted source. Since
    # PyTorch 2.6 torch.load defaults to weights_only=True, which refuses to
    # unpickle whole nn.Module objects — confirm which format this file uses.
    # map_location="cpu": the inference function builds its input tensor on
    # the CPU, so the weights must live there too (also works without a GPU).
    loaded = torch.load(path, map_location="cpu")
    if not isinstance(loaded, torch.nn.Module):
        raise TypeError(
            f"{path} contains a {type(loaded).__name__}, not a torch.nn.Module; "
            "if it is a state_dict, build the model and call load_state_dict()."
        )
    loaded.eval()
    return loaded


def infer_solar_image_heatmap(img):
    """Run the model on *img* and return a figure showing its output as a heatmap.

    Args:
        img: PIL image from the Gradio input, or ``None`` when the user
            submits without uploading anything.

    Returns:
        A new ``matplotlib.figure.Figure`` showing the squeezed model output,
        min-max normalised to [0, 1], drawn with the ``hot`` colormap and no axes.

    Raises:
        gr.Error: if no image was provided (Gradio shows the message to the user).
    """
    if img is None:
        raise gr.Error("Nenhuma imagem enviada.")

    # Preprocess: grayscale, 224x224, scaled to [0, 1], shaped (1, 1, 224, 224).
    # NOTE(review): assumes this is the input format the checkpoint expects — confirm.
    gray = img.convert("L").resize((224, 224))
    img_tensor = torch.tensor(np.array(gray), dtype=torch.float32).unsqueeze(0).unsqueeze(0) / 255.0

    with torch.no_grad():
        outputs = model(img_tensor)

    # Min-max normalisation to [0, 1]; the epsilon avoids dividing by zero
    # when the output is constant.
    emb = outputs.detach().cpu().squeeze().numpy()
    heatmap = emb - emb.min()
    heatmap /= heatmap.max() + 1e-8

    # A fresh Figure per call instead of pyplot's global "current figure":
    # repeated or concurrent requests never draw onto the same canvas, and
    # pyplot's figure manager does not accumulate one figure per request.
    fig = Figure()
    ax = fig.subplots()
    ax.imshow(heatmap, cmap="hot")
    ax.axis("off")
    fig.tight_layout()
    return fig


# Fetch (if needed) and load the model once, at import time.
download_model()
model = _load_model()

# Gradio interface
interface = gr.Interface(
    fn=infer_solar_image_heatmap,
    inputs=gr.Image(type="pil"),
    outputs=gr.Plot(label="Heatmap do embedding Surya"),
    title="Playground Surya-1.0 com Heatmap",
    description="Upload de imagem solar → visualize heatmap gerado pelo Surya-1.0",
)

if __name__ == "__main__":
    interface.launch()