"""GeoGuessr: Player vs AI.

A small Gradio app in which the player and an image-classification model
both guess the country of an image drawn from the GeoGuessr test split.
"""

import random

import gradio as gr
import torch
from datasets import load_dataset
from transformers import AutoImageProcessor, AutoModelForImageClassification

# Load model from Hugging Face
model_name = "Jordiett/convnextv2-geoguessr"
processor = AutoImageProcessor.from_pretrained(model_name)
model = AutoModelForImageClassification.from_pretrained(model_name)

# Load dataset
print("Loading GeoGuessr dataset...")
dataset = load_dataset("marcelomoreno26/geoguessr", split="test")
print(f"Loaded {len(dataset)} test images")

# List of countries the model can predict
countries = list(model.config.id2label.values())


class GameState:
    """Mutable score and current-round state for the game.

    NOTE(review): a single module-level instance is used, so every browser
    session connected to the app shares the same scores and round.
    """

    def __init__(self):
        self.player_score = 0
        self.ai_score = 0
        self.rounds = 0
        self.current_image = None
        self.correct_country = None
        self.ai_prediction = None
        self.ai_top3 = None
        self.options = []
        # Dataset indices already shown; a set keeps membership tests O(1).
        self.used_indices = set()
        # True between new_round() and the first check_answer() for it, so
        # a round can only be scored once.
        self.round_active = False


game = GameState()


def get_ai_prediction(image):
    """Classify *image* with the loaded model.

    Returns:
        A tuple ``(label, top)`` where ``label`` is the highest-scoring
        class name and ``top`` is a list of ``(class name, probability)``
        pairs for the (at most) three most probable classes, in
        descending probability order.
    """
    inputs = processor(images=image, return_tensors="pt")
    with torch.no_grad():
        logits = model(**inputs).logits

    id2label = model.config.id2label
    predicted_id = logits.argmax(-1).item()
    probabilities = torch.nn.functional.softmax(logits, dim=-1)[0]

    # Clamp k so a model with fewer than three classes does not make
    # torch.topk raise.
    k = min(3, probabilities.numel())
    top_prob, top_idx = torch.topk(probabilities, k)
    top3_countries = [
        (id2label[idx], prob)
        for idx, prob in zip(top_idx.tolist(), top_prob.tolist())
    ]

    return id2label[predicted_id], top3_countries


def generate_options(correct_country, ai_prediction):
    """Build the shuffled list of answer choices for a round.

    The list always contains ``correct_country``; it also contains
    ``ai_prediction`` when the AI is wrong (makes it more interesting).
    Remaining slots, up to four choices in total, are filled with random
    other countries. Fewer than four choices are returned only when the
    model does not know enough distinct countries.
    """
    options = [correct_country]
    if ai_prediction != correct_country:
        options.append(ai_prediction)

    other_countries = [c for c in countries if c not in options]
    # Clamp so random.sample never asks for more items than exist.
    needed = min(4 - len(options), len(other_countries))
    options.extend(random.sample(other_countries, needed))

    random.shuffle(options)
    return options


def new_round():
    """Start a new round with a random, not-yet-used dataset image.

    Returns:
        Values for ``[image_display, question, options, submit_btn,
        scoreboard, result]``.
    """
    available_indices = [
        i for i in range(len(dataset)) if i not in game.used_indices
    ]
    if not available_indices:
        # Every image has been shown: start over with the full set.
        game.used_indices = set()
        available_indices = list(range(len(dataset)))

    idx = random.choice(available_indices)
    game.used_indices.add(idx)

    sample = dataset[idx]
    image = sample["image"]
    # NOTE(review): assumes sample["label"] uses the same country names as
    # model.config.id2label — confirm against the dataset schema.
    game.correct_country = sample["label"]
    game.current_image = image

    # Get the AI prediction now, but don't show it until the player answers.
    ai_pred, top3 = get_ai_prediction(image)
    game.ai_prediction = ai_pred
    game.ai_top3 = top3

    game.options = generate_options(game.correct_country, ai_pred)
    game.round_active = True

    return (
        image,
        "🌍 **Where do you think this image is from?**\n\nMake your choice before seeing the AI's prediction!",
        gr.update(choices=game.options, value=None, visible=True),
        gr.update(visible=True),
        f"🎮 Player: {game.player_score} | 🤖 AI: {game.ai_score} | 🎯 Rounds: {game.rounds}",
        "",  # Clear previous result
    )


def check_answer(player_choice):
    """Score the current round and describe the outcome.

    A round is scored at most once: further clicks on "Check Answer"
    before the next "New Round" leave the scores and the displayed result
    untouched.

    Returns:
        Values for ``[result, start_btn]``.
    """
    if not game.round_active:
        # Already scored (or no round in progress): change nothing.
        return gr.update(), gr.update(visible=True)

    if player_choice is None:
        return "⚠️ Please select an option!", gr.update(visible=True)

    game.round_active = False
    game.rounds += 1

    player_correct = player_choice == game.correct_country
    if player_correct:
        game.player_score += 1

    ai_correct = game.ai_prediction == game.correct_country
    if ai_correct:
        game.ai_score += 1

    if player_correct and ai_correct:
        verdict = "🎉 **It's a tie!** Both you and the AI got it right!\n"
    elif player_correct:
        verdict = "🏆 **You win!** The AI was wrong.\n"
    elif ai_correct:
        verdict = "🤖 **AI wins!** You were wrong.\n"
    else:
        verdict = "❌ **Both failed!**\n"

    parts = [
        f"## 🎯 Round {game.rounds} Result\n\n",
        f"**Correct country:** {game.correct_country}\n\n",
        verdict,
        f"\n**Your answer:** {player_choice} {'✅' if player_correct else '❌'}\n",
        f"**AI prediction:** {game.ai_prediction} {'✅' if ai_correct else '❌'}\n",
        "\n**AI's top 3 predictions:**\n",
    ]
    parts.extend(
        f"- {country}: {prob*100:.1f}%\n" for country, prob in game.ai_top3
    )

    # game.rounds was just incremented, so it is at least 1 here.
    player_rate = (game.player_score / game.rounds) * 100
    ai_rate = (game.ai_score / game.rounds) * 100
    parts.extend(
        [
            "\n---\n",
            f"**Your accuracy:** {player_rate:.1f}% ({game.player_score}/{game.rounds})\n",
            f"**AI accuracy:** {ai_rate:.1f}% ({game.ai_score}/{game.rounds})\n",
        ]
    )

    return "".join(parts), gr.update(visible=True)


def reset_game():
    """Reset scores, round count and the used-image history.

    Returns:
        Values for ``[image_display, question, options, submit_btn,
        scoreboard, result]``.
    """
    game.player_score = 0
    game.ai_score = 0
    game.rounds = 0
    game.used_indices = set()
    # Any round that was in progress can no longer be scored.
    game.round_active = False

    return (
        None,
        "🎮 **Game reset!** Click 'New Round' to start playing.",
        gr.update(choices=[], value=None, visible=False),
        gr.update(visible=False),
        "🎮 Player: 0 | 🤖 AI: 0 | 🎯 Rounds: 0",
        "",
    )


# Gradio Interface
with gr.Blocks(theme=gr.themes.Soft(), title="GeoGuessr: Player vs AI") as demo:
    gr.Markdown("""
    # 🌍 GeoGuessr: Player vs AI

    Compete against an AI trained with ConvNeXt V2 to guess countries from Google Street View images!

    **How to play:**
    1. Click "🎮 New Round" to load a random image from the GeoGuessr dataset
    2. Choose one of the 4 proposed countries
    3. Click "✅ Check Answer" to see if you beat the AI!

    **Model:** ConvNeXt V2 Base (61% accuracy, 51.77% F1-macro)
    """)

    with gr.Row():
        with gr.Column(scale=2):
            image_display = gr.Image(type="pil", label="📸 Street View Image", interactive=False)
            start_btn = gr.Button("🎮 New Round", variant="primary", size="lg")

        with gr.Column(scale=1):
            scoreboard = gr.Markdown("🎮 Player: 0 | 🤖 AI: 0 | 🎯 Rounds: 0")
            reset_btn = gr.Button("🔄 Reset Game", variant="secondary")
            question = gr.Markdown("⬇️ Click 'New Round' to start!")
            options = gr.Radio(
                choices=[],
                label="🌍 Select the country:",
                visible=False
            )
            submit_btn = gr.Button("✅ Check Answer", variant="primary", visible=False)
            result = gr.Markdown("")

    # Events
    start_btn.click(
        fn=new_round,
        inputs=[],
        outputs=[image_display, question, options, submit_btn, scoreboard, result]
    )

    submit_btn.click(
        fn=check_answer,
        inputs=[options],
        outputs=[result, start_btn]
    )

    reset_btn.click(
        fn=reset_game,
        outputs=[image_display, question, options, submit_btn, scoreboard, result]
    )

    gr.Markdown("""
    ---
    **Dataset:** [GeoGuessr by marcelomoreno26](https://huggingface.co/datasets/marcelomoreno26/geoguessr)

    **Model:** [ConvNeXt V2 GeoGuessr by Jordiett](https://huggingface.co/Jordiett/convnextv2-geoguessr)

    Images are randomly selected from the test set of the GeoGuessr dataset.
    """)

if __name__ == "__main__":
    demo.launch()