Spaces:
Sleeping
Sleeping
File size: 6,998 Bytes
c105678 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 |
import os
import torch
from flask import Flask, request, jsonify, send_file
from werkzeug.utils import secure_filename
from PIL import Image
import io
import zipfile
from diffusers import ShapEImg2ImgPipeline
from diffusers.utils import export_to_obj
app = Flask(__name__)

# Configure upload/result folders and the accepted image extensions.
UPLOAD_FOLDER = 'uploads'
RESULTS_FOLDER = 'results'
ALLOWED_EXTENSIONS = {'png', 'jpg', 'jpeg'}
os.makedirs(UPLOAD_FOLDER, exist_ok=True)
os.makedirs(RESULTS_FOLDER, exist_ok=True)
app.config['UPLOAD_FOLDER'] = UPLOAD_FOLDER
app.config['MAX_CONTENT_LENGTH'] = 16 * 1024 * 1024  # 16MB max

# Initialize the model (will download on first run).
device = "cuda" if torch.cuda.is_available() else "cpu"
# Half precision is only a safe choice on GPU: on CPU many kernels are either
# missing for float16 or very slow, so fall back to float32 there.
model_dtype = torch.float16 if device == "cuda" else torch.float32
pipe = ShapEImg2ImgPipeline.from_pretrained(
    "openai/shap-e-img2img",
    torch_dtype=model_dtype,
)
pipe = pipe.to(device)
def allowed_file(filename, *, allowed_extensions=None):
    """Return True if *filename* ends in an accepted extension.

    The extension is the text after the last dot, compared
    case-insensitively.

    Args:
        filename: Client-supplied file name. Falsy values (``None`` or ``''``)
            are rejected instead of raising.
        allowed_extensions: Optional collection of lower-case extensions
            (without the dot). Defaults to the module-level
            ``ALLOWED_EXTENSIONS``.

    Returns:
        bool: whether the extension is in the accepted set.
    """
    if not filename or '.' not in filename:
        return False
    if allowed_extensions is None:
        allowed_extensions = ALLOWED_EXTENSIONS
    extension = filename.rsplit('.', 1)[1].lower()
    return extension in allowed_extensions
@app.route('/health', methods=['GET'])
def health_check():
    """Report service status and the model name as JSON with HTTP 200."""
    payload = {"status": "healthy", "model": "Shap-E Image to 3D"}
    return jsonify(payload), 200
def _parse_bounded(form, name, cast, default, low, high):
    """Read form field *name*, convert it with *cast* and range-check it.

    A missing or empty field yields *default*. Raises ValueError carrying a
    client-facing message when the value cannot be parsed or lies outside
    the inclusive range [low, high].
    """
    raw = form.get(name)
    if raw is None or raw == '':
        return default
    try:
        value = cast(raw)
    except (TypeError, ValueError):
        raise ValueError(f"Invalid value for '{name}': {raw!r}") from None
    # This comparison is False for NaN, and infinities fall outside the
    # bounds, so non-finite floats are rejected here as well.
    if not low <= value <= high:
        raise ValueError(f"'{name}' must be between {low} and {high}")
    return value


@app.route('/convert', methods=['POST'])
def convert_image_to_3d():
    """Convert an uploaded image to a 3D model and return it as a download.

    Multipart form fields:
        image: The image file (required; png, jpg or jpeg).
        guidance_scale: Float in 1.0-5.0 (default 3.0).
        num_inference_steps: Integer in 32-128 (default 64).
        output_format: "obj" (returned as model.zip) or "glb" (default "obj").

    Returns:
        The generated file as an attachment on success, a JSON error with
        status 400 for invalid input, or a JSON error with status 500 when
        generation or export fails.

    The upload is decoded directly from the request stream rather than being
    written to UPLOAD_FOLDER under the client-supplied name, so concurrent
    requests that use the same file name cannot overwrite each other's input.
    """
    import uuid

    # Check if image is in the request
    if 'image' not in request.files:
        return jsonify({"error": "No image provided"}), 400
    upload = request.files['image']
    if not upload.filename:
        return jsonify({"error": "No image selected"}), 400
    if not allowed_file(upload.filename):
        # Sort so the message is stable; set iteration order is not.
        supported = ', '.join(sorted(ALLOWED_EXTENSIONS))
        return jsonify({"error": f"File type not allowed. Supported types: {supported}"}), 400

    # Get optional parameters; malformed or out-of-range values are a client
    # error (400), and the bounds keep a single request's work bounded.
    try:
        guidance_scale = _parse_bounded(
            request.form, 'guidance_scale', float, 3.0, 1.0, 5.0)
        num_inference_steps = _parse_bounded(
            request.form, 'num_inference_steps', int, 64, 32, 128)
    except ValueError as exc:
        return jsonify({"error": str(exc)}), 400
    output_format = request.form.get('output_format', 'obj').lower()

    # Validate output format
    if output_format not in ('obj', 'glb'):
        return jsonify({"error": "Unsupported output format. Use 'obj' or 'glb'"}), 400

    # Decode the image; unreadable or corrupt data is a client error.
    try:
        image = Image.open(upload.stream).convert("RGB")
    except (OSError, ValueError) as exc:
        return jsonify({"error": f"Could not read image: {exc}"}), 400

    try:
        # Generate 3D model
        images = pipe(
            image,
            guidance_scale=guidance_scale,
            num_inference_steps=num_inference_steps,
            output_type="mesh",
        ).images
        mesh = images[0]

        # Create unique output directory. The path is made absolute so the
        # file is found no matter which base directory send_file uses to
        # resolve relative paths.
        output_dir = os.path.abspath(
            os.path.join(RESULTS_FOLDER, str(uuid.uuid4())))
        os.makedirs(output_dir, exist_ok=True)

        # Export to requested format
        if output_format == 'obj':
            obj_path = os.path.join(output_dir, "model.obj")
            export_to_obj(mesh, obj_path)
            # Create a zip file with the OBJ and, if one was written, the MTL
            zip_path = os.path.join(output_dir, "model.zip")
            with zipfile.ZipFile(zip_path, 'w') as zipf:
                zipf.write(obj_path, arcname="model.obj")
                mtl_path = os.path.join(output_dir, "model.mtl")
                if os.path.exists(mtl_path):
                    zipf.write(mtl_path, arcname="model.mtl")
            return send_file(zip_path, as_attachment=True, download_name="model.zip")

        # output_format == 'glb': build a trimesh object and export it.
        from trimesh import Trimesh
        # The mesh attributes are expected to be torch tensors (possibly on
        # the GPU and in half precision); convert them to CPU NumPy arrays
        # before handing them to trimesh.
        vertices = mesh.verts.detach().cpu().float().numpy()
        faces = mesh.faces.detach().cpu().numpy()
        trimesh_obj = Trimesh(vertices=vertices, faces=faces)
        glb_path = os.path.join(output_dir, "model.glb")
        trimesh_obj.export(glb_path)
        return send_file(glb_path, as_attachment=True, download_name="model.glb")
    except Exception as e:
        # Top-level boundary for the request: log the traceback, report JSON.
        app.logger.exception("Image to 3D conversion failed")
        return jsonify({"error": str(e)}), 500
@app.route('/', methods=['GET'])
def index():
    """Serve the landing page: an upload form plus short API documentation."""
    # The markup is kept flush-left so the served text is exactly the literal.
    page = """
<html>
<head>
<title>Image to 3D Model Converter</title>
<style>
body { font-family: Arial, sans-serif; max-width: 800px; margin: 0 auto; padding: 20px; }
h1 { color: #333; }
form { margin: 20px 0; padding: 20px; border: 1px solid #ddd; border-radius: 5px; }
label { display: block; margin: 10px 0 5px; }
input, select { margin-bottom: 10px; padding: 8px; width: 100%; }
button { background: #4CAF50; color: white; border: none; padding: 10px 15px; cursor: pointer; }
.api-info { background: #f5f5f5; padding: 15px; border-radius: 5px; }
pre { background: #eee; padding: 10px; overflow-x: auto; }
</style>
</head>
<body>
<h1>Image to 3D Model Converter</h1>
<form action="/convert" method="post" enctype="multipart/form-data">
<label for="image">Upload Image:</label>
<input type="file" id="image" name="image" accept=".png,.jpg,.jpeg" required>
<label for="guidance_scale">Guidance Scale (1.0-5.0):</label>
<input type="number" id="guidance_scale" name="guidance_scale" min="1.0" max="5.0" step="0.1" value="3.0">
<label for="num_inference_steps">Inference Steps (32-128):</label>
<input type="number" id="num_inference_steps" name="num_inference_steps" min="32" max="128" value="64">
<label for="output_format">Output Format:</label>
<select id="output_format" name="output_format">
<option value="obj">OBJ (for Unity)</option>
<option value="glb">GLB (for Three.js/Unreal)</option>
</select>
<button type="submit">Convert to 3D</button>
</form>
<div class="api-info">
<h2>API Documentation</h2>
<p>Endpoint: <code>/convert</code> (POST)</p>
<p>Parameters:</p>
<ul>
<li><code>image</code>: Image file (required)</li>
<li><code>guidance_scale</code>: Float between 1.0-5.0 (default: 3.0)</li>
<li><code>num_inference_steps</code>: Integer between 32-128 (default: 64)</li>
<li><code>output_format</code>: "obj" or "glb" (default: "obj")</li>
</ul>
<p>Example curl request:</p>
<pre>curl -X POST -F "image=@your_image.jpg" -F "output_format=obj" http://localhost:5000/convert -o model.zip</pre>
</div>
</body>
</html>
"""
    return page
if __name__ == '__main__':
    # Listen on all interfaces; the port comes from $PORT when set, else 5000.
    listen_port = int(os.environ.get('PORT', 5000))
    app.run(host='0.0.0.0', port=listen_port)
|