Tas01 committed
Commit c95b9e0 · verified · 1 Parent(s): 4888ca6

Update app.py

Files changed (1)
  1. app.py +4 -271
app.py CHANGED
@@ -147,121 +147,6 @@ class ImageStoryteller:
 
 
 
-    # def generate_story(self, analysis_result, creativity_level=0.7):
-    #     """Generate a story based on detected objects and scene using Qwen"""
-    #     if self.llm_model is None:
-    #         return "Story generation model not available."
-
-    #     try:
-    #         # Extract detected objects and scene
-    #         objects = [obj['name'] for obj in analysis_result['objects']]
-    #         scenes = [scene['type'] for scene in analysis_result['scenes']]
-
-    #         # Create a prompt for the LLM
-    #         objects_str = ", ".join(objects) # Use top 3 objects
-    #         scene_str = scenes[0] if scenes else "general scene"
-
-    #         # FIXED: Convert creativity_level to float if it's a tuple
-    #         if isinstance(creativity_level, (tuple, list)):
-    #             creativity_level = float(creativity_level[0])
-
-    #         # Different prompt templates for creativity
-    #         if creativity_level > 0.8:
-    #             prompt = f"""Based on this image containing {objects_str} in a {scene_str}, write a creative and imaginative short story (3-4 paragraphs).
-    #             Make it engaging and add interesting details about the scene. Story:"""
-    #         elif creativity_level > 0.5:
-    #             prompt = f"""Create a short story about an image with {objects_str} in a {scene_str}.
-    #             Write 2-3 paragraphs that describe what might be happening in this scene. Story:"""
-    #         else:
-    #             prompt = f"""Describe what you see in an image containing {objects_str} in a {scene_str}.
-    #             Write a simple 1-2 paragraph description. Description:"""
-
-    #         # QWEN 1.8B SPECIFIC FORMATTING - SIMPLE AND EFFECTIVE
-    #         if "qwen" in self.llm_model_id.lower():
-    #             # Qwen works best with this simple format
-    #             formatted_prompt = f"<|im_start|>user\n{prompt}<|im_end|>\n<|im_start|>assistant\n"
-    #         elif "phi" in self.llm_model_id: # For Phi models
-    #             # Phi-2 specific formatting
-    #             formatted_prompt = f"Instruct: {prompt}\nOutput:"
-    #         elif "gemma" in self.llm_model_id:
-    #             # Gemma specific formatting
-    #             formatted_prompt = f"<start_of_turn>user\n{prompt}<end_of_turn>\n<start_of_turn>model\n"
-    #         else:
-    #             # Generic formatting
-    #             formatted_prompt = f"{prompt}\n\n"
-
-    #         # Tokenize and generate
-    #         inputs = self.tokenizer(formatted_prompt, return_tensors="pt").to(self.llm_model.device)
-
-    #         with torch.no_grad():
-    #             # QWEN OPTIMIZED GENERATION PARAMETERS
-    #             if "qwen" in self.llm_model_id.lower():
-    #                 outputs = self.llm_model.generate(
-    #                     **inputs,
-    #                     max_new_tokens=300, # Good length for stories
-    #                     temperature=creativity_level,
-    #                     do_sample=True,
-    #                     top_p=0.9,
-    #                     repetition_penalty=1.1,
-    #                     eos_token_id=self.tokenizer.eos_token_id,
-    #                     pad_token_id=self.tokenizer.eos_token_id,
-    #                     no_repeat_ngram_size=3 # Prevent repetition
-    #                 )
-    #             else:
-    #                 outputs = self.llm_model.generate(
-    #                     **inputs,
-    #                     max_new_tokens=250,
-    #                     temperature=creativity_level,
-    #                     do_sample=True,
-    #                     top_p=0.9,
-    #                     repetition_penalty=1.1,
-    #                     pad_token_id=self.tokenizer.eos_token_id
-    #                 )
-
-    #         # Decode and clean up
-    #         story = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
-
-    #         # Clean up Qwen specific tokens
-    #         if "qwen" in self.llm_model_id.lower():
-    #             # Remove the prompt and Qwen chat tokens
-    #             story = story.replace(formatted_prompt, "").strip()
-    #             story = story.replace("<|im_end|>", "").strip()
-    #             story = story.replace("<|im_start|>", "").strip()
-    #             story = story.replace("<|endoftext|>", "").strip()
-
-    #             # Sometimes Qwen repeats, clean that up
-    #             if "Story:" in story:
-    #                 story = story.split("Story:")[-1].strip()
-    #             if "Description:" in story:
-    #                 story = story.split("Description:")[-1].strip()
-    #         elif story.startswith(formatted_prompt):
-    #             story = story[len(formatted_prompt):].strip()
-
-    #         # Additional cleanup for any model
-    #         story = story.strip()
-
-    #         # If story is too short, try a simpler approach
-    #         if len(story.split()) < 10:
-    #             # Fallback: use a direct prompt
-    #             simple_prompt = f"Tell me a story about {objects_str} in {scene_str}."
-    #             simple_inputs = self.tokenizer(simple_prompt, return_tensors="pt").to(self.llm_model.device)
-    #             with torch.no_grad():
-    #                 simple_outputs = self.llm_model.generate(
-    #                     **simple_inputs,
-    #                     max_new_tokens=200,
-    #                     temperature=0.8,
-    #                     do_sample=True
-    #                 )
-    #             story = self.tokenizer.decode(simple_outputs[0], skip_special_tokens=True)
-    #             story = story.replace(simple_prompt, "").strip()
-
-    #         # return story
-
-    #     except Exception as e:
-    #         print(f"Story generation failed: {e}")
-    #         objects_str = ", ".join(objects) if 'objects' in locals() else "unknown"
-    #         scene_str = scenes[0] if 'scenes' in locals() and scenes else "unknown scene"
-    #         return f"Failed to generate story. Detected objects: {objects_str} in a {scene_str}. Error: {str(e)}"
     def generate_story(self, analysis_result, creativity_level=0.7):
         """Generate a story with caption based on detected objects and scene using Qwen"""
         if self.llm_model is None:
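Aside: the block deleted above hand-builds chat markers per model family ("<|im_start|>"/"<|im_end|>" for Qwen, "Instruct:"/"Output:" for Phi, "<start_of_turn>" for Gemma). A minimal sketch of the same formatting done through the tokenizer's built-in chat template follows, assuming the loaded tokenizer ships one; the checkpoint name in the usage comment is illustrative, not taken from this Space.

from transformers import AutoTokenizer

def format_prompt(tokenizer, prompt: str) -> str:
    # Let the tokenizer's own chat template insert the model-specific role
    # markers instead of hard-coding <|im_start|> / <start_of_turn> strings.
    messages = [{"role": "user", "content": prompt}]
    return tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )

# Illustrative usage (checkpoint name is an assumption, not from this Space):
# tok = AutoTokenizer.from_pretrained("Qwen/Qwen1.5-1.8B-Chat")
# formatted = format_prompt(tok, "Write a short story about a dog on a beach.")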
@@ -402,65 +287,7 @@ class ImageStoryteller:
 
         return formatted_text
 
-    # FIXED: Added the missing method
-    # def create_story_overlay(self, image, story):
-    #     """Create a simple overlay showing story - returns just the story text"""
-    #     # If you want to create an image with text, you can implement it here
-    #     # For now, let's just return the story text
-    #     return story
-
-    # def create_story_overlay(self, image, story):
-    #     """Create story overlay as separate black image with bigger fonts"""
-    #     img_np = np.array(image)
-    #     height, width = 800, 800 #img_np.shape[:2]
-
-    #     # Create a separate black image for the story (1/3 of original height)
-    #     overlay_height = height // 1
-    #     overlay = np.zeros((overlay_height, width, 3), dtype=np.uint8)
-
-    #     # Add text to the black overlay with bigger fonts
-    #     font = cv2.FONT_HERSHEY_SIMPLEX
-    #     font_scale = 1 # Much bigger font
-    #     font_color = (255, 255, 255) # White text
-    #     thickness = 1 # Thicker text
-    #     line_spacing = 25 # More spacing for bigger text
-
-    #     # Split story into lines (max 40 characters per line for bigger text)
-    #     words = story.split()
-    #     lines = []
-    #     current_line = ""
-
-    #     for word in words:
-    #         if len(current_line + word) <= 40:
-    #             current_line += word + " "
-    #         else:
-    #             lines.append(current_line.strip())
-    #             current_line = word + " "
-    #     if current_line:
-    #         lines.append(current_line.strip())
-
-    #     # Limit to 5 lines maximum for bigger text
-    #     # if len(lines) > 5: # If you want to keep some limit but indicate truncation
-    #     #     lines = lines[:5]
-    #     #     lines.append("... [Story continues]") # Indicate truncation
-
-
-    #     # Calculate text block height for centering
-    #     total_text_height = len(lines) * line_spacing
-    #     start_y = (overlay_height - total_text_height) // + 60
-
-    #     # Add each line of text, centered
-    #     y_offset = start_y
-    #     for line in lines:
-    #         # Calculate text size for centering
-    #         text_size = cv2.getTextSize(line, font, font_scale, thickness)[0]
-    #         text_x = (width - text_size[0]) // 2
-
-    #         cv2.putText(overlay, line, (text_x, y_offset),
-    #                     font, font_scale, font_color, thickness, cv2.LINE_AA)
-    #         y_offset += line_spacing
-
-    #     return Image.fromarray(overlay)
+
 
     def remove_background(self, image):
         """Remove background using rembg"""
@@ -524,7 +351,7 @@ class ImageStoryteller:
         analysis_result = self.analyze_image_with_clip(image)
 
         # Generate story
-        story = self.generate_story(analysis_result, image.size)
+        story = self.generate_story(analysis_result, creativity_level=0.7)
 
         # # Create analysis overlay
         # analysis_image = self.create_analysis_overlay(image, analysis_result)
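This is the one functional change in the commit. The old call passed image.size, a PIL (width, height) tuple, positionally into the creativity_level parameter; that is exactly what the tuple-to-float workaround in the deleted generate_story draft was compensating for. The new call passes an explicit float keyword, so temperature=creativity_level receives a plain number inside generate_story:

# Before: the (width, height) tuple landed in creativity_level,
# so generate_story had to coerce it back to a float.
# story = self.generate_story(analysis_result, image.size)

# After: an explicit float. Since 0.7 is also the parameter's default,
# self.generate_story(analysis_result) would behave the same.
story = self.generate_story(analysis_result, creativity_level=0.7)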
@@ -578,55 +405,7 @@ with gr.Blocks(title="Who says AI isn’t creative? Watch it turn a single image
     # Load example images
     example_images_list = get_example_images()
 
-    # # Custom CSS to remove gallery selection frames
-    # #custom_css = """
-    # #.gradio-container {
-    # #    max-height: 95vh !important;
-    # #    overflow-y: auto !important;
-    # #}
-
-    # #blocks-container {
-    #     max-height: 100% !important;
-    #     overflow: auto !important;
-    # }
-
-    # #.gallery .wrap.contain .grid .wrap {
-    #     border: none !important;
-    #     box-shadow: none !important;
-    #     outline: none !important;
-    # }
-    # .gallery .wrap.contain .grid .wrap.selected {
-    #     border: none !important;
-    #     box-shadow: none !important;
-    #     outline: none !important;
-    # }
-    # .gallery .thumbnail {
-    #     border: none !important;
-    #     box-shadow: none !important;
-    #     outline: none !important;
-    # }
-    # .gallery .thumbnail.selected {
-    #     border: none !important;
-    #     box-shadow: none !important;
-    #     outline: none !important;
-    # }
-    # .gallery .wrap.gradio-image {
-    #     border: none !important;
-    #     box-shadow: none !important;
-    #     outline: none !important;
-    # }
-    # .gallery .wrap.gradio-image.selected {
-    #     border: none !important;
-    #     box-shadow: none !important;
-    #     outline: none !important;
-    # }
-    # /* Prevent infinite expansion */
-    # .panel {
-    #     max-height: 80vh !important;
-    #     overflow-y: auto !important;
-    # }
-    # #"""
-
+
 
     custom_css = """
     <style>
@@ -655,53 +434,7 @@ with gr.Blocks(title="Who says AI isn’t creative? Watch it turn a single image
     </style>
     """
 
-    # Add this JavaScript separately
-    # javascript = """
-    # <script>
-    # document.addEventListener('DOMContentLoaded', function() {
-    #     // Force container height immediately
-    #     const forceHeight = function() {
-    #         const containers = document.querySelectorAll('.gradio-container, .container, #blocks-container');
-    #         containers.forEach(container => {
-    #             container.style.height = '100vh';
-    #             container.style.maxHeight = '100vh';
-    #             container.style.overflowY = 'auto';
-    #         });
-    #     };
-
-    #     // Run immediately
-    #     forceHeight();
-
-    #     // Run again after a short delay to catch dynamic content
-    #     setTimeout(forceHeight, 100);
-    #     setTimeout(forceHeight, 500);
-
-    #     // Monitor for ANY DOM changes
-    #     const observer = new MutationObserver(function(mutations) {
-    #         forceHeight();
-    #         // Constrain any new elements
-    #         mutations.forEach(function(mutation) {
-    #             mutation.addedNodes.forEach(function(node) {
-    #                 if (node.nodeType === 1) {
-    #                     node.style.maxHeight = '100%';
-    #                     if (node.querySelectorAll) {
-    #                         node.querySelectorAll('*').forEach(child => {
-    #                             child.style.maxHeight = '100%';
-    #                         });
-    #                     }
-    #                 }
-    #             });
-    #         });
-    #     });
-
-    #     observer.observe(document.body, {
-    #         childList: true,
-    #         subtree: true,
-    #         attributes: true,
-    #         attributeFilter: ['style', 'class']
-    #     });
-    # });
-    # </script>
+
     javascript = """
     <script>
     document.addEventListener('DOMContentLoaded', function() {