Spaces:
Sleeping
Sleeping
| import os | |
| import torch | |
| from flask import Flask, request, jsonify, send_file | |
| from werkzeug.utils import secure_filename | |
| from PIL import Image | |
| import io | |
| import zipfile | |
| from diffusers import ShapEImg2ImgPipeline | |
| from diffusers.utils import export_to_obj | |
app = Flask(__name__)

# Configure upload/result folders and request limits.
UPLOAD_FOLDER = 'uploads'
RESULTS_FOLDER = 'results'
ALLOWED_EXTENSIONS = {'png', 'jpg', 'jpeg'}
os.makedirs(UPLOAD_FOLDER, exist_ok=True)
os.makedirs(RESULTS_FOLDER, exist_ok=True)
app.config['UPLOAD_FOLDER'] = UPLOAD_FOLDER
# Flask answers requests whose body exceeds this size with 413 Request Entity Too Large.
app.config['MAX_CONTENT_LENGTH'] = 16 * 1024 * 1024  # 16MB max

# Initialize the model once at import time (weights download on first run).
device = "cuda" if torch.cuda.is_available() else "cpu"
# Half precision is only dependable on the GPU; many CPU kernels have no
# float16 implementation, so fall back to float32 when running on CPU.
model_dtype = torch.float16 if device == "cuda" else torch.float32
pipe = ShapEImg2ImgPipeline.from_pretrained("openai/shap-e-img2img", torch_dtype=model_dtype)
pipe = pipe.to(device)
def allowed_file(filename):
    """Tell whether *filename* carries one of the accepted image extensions.

    Only the text after the final dot is considered, compared case-insensitively
    against ALLOWED_EXTENSIONS. A name containing no dot is never accepted.
    """
    if '.' not in filename:
        return False
    suffix = filename.rpartition('.')[2]
    return suffix.lower() in ALLOWED_EXTENSIONS
@app.route('/health', methods=['GET'])
def health_check():
    """Report that the service is running and which model it serves.

    Returns a JSON body with HTTP 200. This does not invoke the model pipeline.
    """
    return jsonify({"status": "healthy", "model": "Shap-E Image to 3D"}), 200
# Inclusive limits for the tunable generation parameters. They match the ranges
# advertised by the HTML form and API documentation returned by index().
_GUIDANCE_SCALE_RANGE = (1.0, 5.0)
_INFERENCE_STEPS_RANGE = (32, 128)


def _parse_generation_params(form):
    """Read and validate the optional generation parameters from *form*.

    Missing or blank values fall back to the defaults (3.0, 64, 'obj').
    Returns ``(guidance_scale, num_inference_steps, output_format)``.
    Raises ValueError with a client-facing message when a value is malformed
    or outside its allowed range.
    """
    try:
        # `or` treats a blank field (e.g. a cleared HTML number input) as "use default".
        guidance_scale = float(form.get('guidance_scale') or 3.0)
        num_inference_steps = int(form.get('num_inference_steps') or 64)
    except (TypeError, ValueError):
        raise ValueError(
            "guidance_scale must be a number and num_inference_steps an integer"
        ) from None

    output_format = form.get('output_format', 'obj').lower()
    if output_format not in ('obj', 'glb'):
        raise ValueError("Unsupported output format. Use 'obj' or 'glb'")

    low, high = _GUIDANCE_SCALE_RANGE
    # Positive range test, so NaN (which fails every comparison) is rejected too.
    if not low <= guidance_scale <= high:
        raise ValueError(f"guidance_scale must be between {low} and {high}")
    low, high = _INFERENCE_STEPS_RANGE
    # Bounding the step count keeps a single request from monopolising the model.
    if not low <= num_inference_steps <= high:
        raise ValueError(f"num_inference_steps must be between {low} and {high}")

    return guidance_scale, num_inference_steps, output_format


def _export_mesh(mesh, output_format):
    """Serialize a generated mesh into an in-memory download.

    'obj' is written with diffusers' export_to_obj inside a temporary directory
    and packed into a zip archive; 'glb' is built with trimesh. Nothing is left
    on disk afterwards.
    Returns ``(buffer, mimetype, download_name)``.
    """
    if output_format == 'glb':
        from trimesh import Trimesh  # imported lazily; only the GLB path needs trimesh

        # The mesh attributes are torch tensors that may live on the GPU
        # (pipe was moved to `device`); trimesh needs host-side NumPy arrays.
        vertices = mesh.verts.detach().cpu().numpy()
        faces = mesh.faces.detach().cpu().numpy()
        glb_bytes = Trimesh(vertices=vertices, faces=faces).export(file_type='glb')
        return io.BytesIO(glb_bytes), 'model/gltf-binary', 'model.glb'

    import tempfile

    buffer = io.BytesIO()
    with tempfile.TemporaryDirectory() as tmp_dir:
        obj_path = os.path.join(tmp_dir, "model.obj")
        export_to_obj(mesh, obj_path)
        with zipfile.ZipFile(buffer, 'w', compression=zipfile.ZIP_DEFLATED) as zipf:
            zipf.write(obj_path, arcname="model.obj")
            # Include a material file if the exporter produced one.
            mtl_path = os.path.join(tmp_dir, "model.mtl")
            if os.path.exists(mtl_path):
                zipf.write(mtl_path, arcname="model.mtl")
    buffer.seek(0)
    return buffer, 'application/zip', 'model.zip'


@app.route('/convert', methods=['POST'])
def convert_image_to_3d():
    """Convert an uploaded image into a downloadable 3D model.

    Multipart form fields:
        image: required image file (png/jpg/jpeg).
        guidance_scale: optional float in [1.0, 5.0], default 3.0.
        num_inference_steps: optional int in [32, 128], default 64.
        output_format: 'obj' (zip containing model.obj) or 'glb'; default 'obj'.

    Returns the model as an attachment, a 400 JSON error for invalid input,
    or a 500 JSON error if generation or export fails (details are logged).
    """
    if 'image' not in request.files:
        return jsonify({"error": "No image provided"}), 400
    file = request.files['image']
    if not file.filename:
        return jsonify({"error": "No image selected"}), 400
    if not allowed_file(file.filename):
        supported = ', '.join(sorted(ALLOWED_EXTENSIONS))
        return jsonify({"error": f"File type not allowed. Supported types: {supported}"}), 400

    try:
        guidance_scale, num_inference_steps, output_format = _parse_generation_params(request.form)
    except ValueError as exc:
        return jsonify({"error": str(exc)}), 400

    # Decode straight from the upload stream: no client-named file is written to
    # a shared folder, so concurrent uploads cannot clobber each other or pile up.
    try:
        image = Image.open(file.stream).convert("RGB")
    except (OSError, ValueError, Image.DecompressionBombError):
        return jsonify({"error": "Could not read the uploaded file as an image"}), 400

    try:
        meshes = pipe(
            image,
            guidance_scale=guidance_scale,
            num_inference_steps=num_inference_steps,
            output_type="mesh",
        ).images
        payload, mimetype, download_name = _export_mesh(meshes[0], output_format)
    except Exception:
        # Request boundary: keep the traceback in the server log, but do not
        # echo internal exception text (paths, device details) to the client.
        app.logger.exception("Image-to-3D conversion failed")
        return jsonify({"error": "Failed to generate 3D model"}), 500

    return send_file(payload, mimetype=mimetype, as_attachment=True, download_name=download_name)
@app.route('/')
def index():
    """Serve the HTML upload form together with inline API documentation.

    The form submits multipart data to /convert; the documented parameter
    ranges mirror the input limits in the form.
    """
    return """
<html>
<head>
<title>Image to 3D Model Converter</title>
<style>
body { font-family: Arial, sans-serif; max-width: 800px; margin: 0 auto; padding: 20px; }
h1 { color: #333; }
form { margin: 20px 0; padding: 20px; border: 1px solid #ddd; border-radius: 5px; }
label { display: block; margin: 10px 0 5px; }
input, select { margin-bottom: 10px; padding: 8px; width: 100%; }
button { background: #4CAF50; color: white; border: none; padding: 10px 15px; cursor: pointer; }
.api-info { background: #f5f5f5; padding: 15px; border-radius: 5px; }
pre { background: #eee; padding: 10px; overflow-x: auto; }
</style>
</head>
<body>
<h1>Image to 3D Model Converter</h1>
<form action="/convert" method="post" enctype="multipart/form-data">
<label for="image">Upload Image:</label>
<input type="file" id="image" name="image" accept=".png,.jpg,.jpeg" required>
<label for="guidance_scale">Guidance Scale (1.0-5.0):</label>
<input type="number" id="guidance_scale" name="guidance_scale" min="1.0" max="5.0" step="0.1" value="3.0">
<label for="num_inference_steps">Inference Steps (32-128):</label>
<input type="number" id="num_inference_steps" name="num_inference_steps" min="32" max="128" value="64">
<label for="output_format">Output Format:</label>
<select id="output_format" name="output_format">
<option value="obj">OBJ (for Unity)</option>
<option value="glb">GLB (for Three.js/Unreal)</option>
</select>
<button type="submit">Convert to 3D</button>
</form>
<div class="api-info">
<h2>API Documentation</h2>
<p>Endpoint: <code>/convert</code> (POST)</p>
<p>Parameters:</p>
<ul>
<li><code>image</code>: Image file (required)</li>
<li><code>guidance_scale</code>: Float between 1.0-5.0 (default: 3.0)</li>
<li><code>num_inference_steps</code>: Integer between 32-128 (default: 64)</li>
<li><code>output_format</code>: "obj" or "glb" (default: "obj")</li>
</ul>
<p>Example curl request:</p>
<pre>curl -X POST -F "image=@your_image.jpg" -F "output_format=obj" http://localhost:5000/convert -o model.zip</pre>
</div>
</body>
</html>
"""
if __name__ == '__main__':
    # Run the built-in server on all interfaces (0.0.0.0). The PORT
    # environment variable, when set, overrides the default port 5000.
    listen_port = int(os.environ.get('PORT', 5000))
    app.run(host='0.0.0.0', port=listen_port)