Tas01 committed
Commit bc37146 · verified · 1 Parent(s): 1fa30cb

Update app.py

Files changed (1): app.py (+140, -55)
app.py CHANGED
@@ -145,8 +145,10 @@ class ImageStoryteller:
                 'success': False
             }
 
+
+
     # def generate_story(self, analysis_result, creativity_level=0.7):
-    #     """Generate a story based on detected objects and scene"""
+    #     """Generate a story based on detected objects and scene using Qwen"""
     #     if self.llm_model is None:
     #         return "Story generation model not available."
 
@@ -156,7 +158,7 @@ class ImageStoryteller:
     #         scenes = [scene['type'] for scene in analysis_result['scenes']]
 
     #         # Create a prompt for the LLM
-    #         objects_str = ", ".join(objects[:3])  # Use top 3 objects
+    #         objects_str = ", ".join(objects)  # Use top 3 objects
     #         scene_str = scenes[0] if scenes else "general scene"
 
     #         # FIXED: Convert creativity_level to float if it's a tuple
@@ -166,16 +168,19 @@ class ImageStoryteller:
     #         # Different prompt templates for creativity
     #         if creativity_level > 0.8:
     #             prompt = f"""Based on this image containing {objects_str} in a {scene_str}, write a creative and imaginative short story (3-4 paragraphs).
-    #             Make it engaging and add interesting details about the scene."""
+    #             Make it engaging and add interesting details about the scene. Story:"""
     #         elif creativity_level > 0.5:
     #             prompt = f"""Create a short story about an image with {objects_str} in a {scene_str}.
-    #             Write 2-3 paragraphs that describe what might be happening in this scene."""
+    #             Write 2-3 paragraphs that describe what might be happening in this scene. Story:"""
     #         else:
     #             prompt = f"""Describe what you see in an image containing {objects_str} in a {scene_str}.
-    #             Write a simple 1-2 paragraph description."""
+    #             Write a simple 1-2 paragraph description. Description:"""
 
-    #         # Format for the specific LLM
-    #         if "phi" in self.llm_model_id:
+    #         # QWEN 1.8B SPECIFIC FORMATTING - SIMPLE AND EFFECTIVE
+    #         if "qwen" in self.llm_model_id.lower():
+    #             # Qwen works best with this simple format
+    #             formatted_prompt = f"<|im_start|>user\n{prompt}<|im_end|>\n<|im_start|>assistant\n"
+    #         elif "phi" in self.llm_model_id:  # For Phi models
     #             # Phi-2 specific formatting
     #             formatted_prompt = f"Instruct: {prompt}\nOutput:"
     #         elif "gemma" in self.llm_model_id:
@@ -183,39 +188,82 @@ class ImageStoryteller:
     #             formatted_prompt = f"<start_of_turn>user\n{prompt}<end_of_turn>\n<start_of_turn>model\n"
     #         else:
     #             # Generic formatting
-    #             formatted_prompt = f"Write a story: {prompt}\n\nStory:"
+    #             formatted_prompt = f"{prompt}\n\n"
 
     #         # Tokenize and generate
     #         inputs = self.tokenizer(formatted_prompt, return_tensors="pt").to(self.llm_model.device)
 
     #         with torch.no_grad():
-    #             outputs = self.llm_model.generate(
-    #                 **inputs,
-    #                 max_new_tokens=250,  # Shorter for faster generation
-    #                 temperature=creativity_level,
-    #                 do_sample=True,
-    #                 top_p=0.9,
-    #                 repetition_penalty=1.1,
-    #                 pad_token_id=self.tokenizer.eos_token_id
-    #             )
+    #             # QWEN OPTIMIZED GENERATION PARAMETERS
+    #             if "qwen" in self.llm_model_id.lower():
+    #                 outputs = self.llm_model.generate(
+    #                     **inputs,
+    #                     max_new_tokens=300,  # Good length for stories
+    #                     temperature=creativity_level,
+    #                     do_sample=True,
+    #                     top_p=0.9,
+    #                     repetition_penalty=1.1,
+    #                     eos_token_id=self.tokenizer.eos_token_id,
+    #                     pad_token_id=self.tokenizer.eos_token_id,
+    #                     no_repeat_ngram_size=3  # Prevent repetition
+    #                 )
+    #             else:
+    #                 outputs = self.llm_model.generate(
+    #                     **inputs,
+    #                     max_new_tokens=250,
+    #                     temperature=creativity_level,
+    #                     do_sample=True,
+    #                     top_p=0.9,
+    #                     repetition_penalty=1.1,
+    #                     pad_token_id=self.tokenizer.eos_token_id
+    #                 )
 
     #         # Decode and clean up
     #         story = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
 
-    #         # Remove the prompt from the beginning if present
-    #         if story.startswith(formatted_prompt):
+    #         # Clean up Qwen specific tokens
+    #         if "qwen" in self.llm_model_id.lower():
+    #             # Remove the prompt and Qwen chat tokens
+    #             story = story.replace(formatted_prompt, "").strip()
+    #             story = story.replace("<|im_end|>", "").strip()
+    #             story = story.replace("<|im_start|>", "").strip()
+    #             story = story.replace("<|endoftext|>", "").strip()
+
+    #             # Sometimes Qwen repeats, clean that up
+    #             if "Story:" in story:
+    #                 story = story.split("Story:")[-1].strip()
+    #             if "Description:" in story:
+    #                 story = story.split("Description:")[-1].strip()
+    #         elif story.startswith(formatted_prompt):
     #             story = story[len(formatted_prompt):].strip()
 
-    #         return story
+    #         # Additional cleanup for any model
+    #         story = story.strip()
+
+    #         # If story is too short, try a simpler approach
+    #         if len(story.split()) < 10:
+    #             # Fallback: use a direct prompt
+    #             simple_prompt = f"Tell me a story about {objects_str} in {scene_str}."
+    #             simple_inputs = self.tokenizer(simple_prompt, return_tensors="pt").to(self.llm_model.device)
+    #             with torch.no_grad():
+    #                 simple_outputs = self.llm_model.generate(
+    #                     **simple_inputs,
+    #                     max_new_tokens=200,
+    #                     temperature=0.8,
+    #                     do_sample=True
+    #                 )
+    #             story = self.tokenizer.decode(simple_outputs[0], skip_special_tokens=True)
+    #             story = story.replace(simple_prompt, "").strip()
+
+    #         # return story
 
     #     except Exception as e:
     #         print(f"Story generation failed: {e}")
     #         objects_str = ", ".join(objects) if 'objects' in locals() else "unknown"
     #         scene_str = scenes[0] if 'scenes' in locals() and scenes else "unknown scene"
-    #         return f"Failed to generate story. Detected objects: {objects_str} in a {scene_str}."
-
+    #         return f"Failed to generate story. Detected objects: {objects_str} in a {scene_str}. Error: {str(e)}"
     def generate_story(self, analysis_result, creativity_level=0.7):
-        """Generate a story based on detected objects and scene using Qwen"""
+        """Generate a story with caption based on detected objects and scene using Qwen"""
         if self.llm_model is None:
             return "Story generation model not available."
 
@@ -232,52 +280,65 @@ class ImageStoryteller:
         if isinstance(creativity_level, (tuple, list)):
             creativity_level = float(creativity_level[0])
 
-        # Different prompt templates for creativity
+        # Enhanced prompt with caption generation
         if creativity_level > 0.8:
-            prompt = f"""Based on this image containing {objects_str} in a {scene_str}, write a creative and imaginative short story (3-4 paragraphs).
-            Make it engaging and add interesting details about the scene. Story:"""
+            prompt = f"""Based on this image containing {objects_str} in a {scene_str}:
+
+1. First, write a catchy 5-7 word YouTube-style caption (engaging, attention-grabbing)
+2. Then, write a creative and imaginative short story (3-4 paragraphs)
+
+Format exactly like this:
+CAPTION: [your catchy caption here]
+STORY: [your creative story here]"""
         elif creativity_level > 0.5:
-            prompt = f"""Create a short story about an image with {objects_str} in a {scene_str}.
-            Write 2-3 paragraphs that describe what might be happening in this scene. Story:"""
+            prompt = f"""For an image with {objects_str} in a {scene_str}:
+
+1. Create a short, interesting caption (5-7 words)
+2. Write a 2-3 paragraph story about what's happening in this scene
+
+Format:
+CAPTION: [your caption here]
+STORY: [your story here]"""
         else:
-            prompt = f"""Describe what you see in an image containing {objects_str} in a {scene_str}.
-            Write a simple 1-2 paragraph description. Description:"""
+            prompt = f"""Describe an image containing {objects_str} in a {scene_str}:
+
+1. Give a simple, descriptive caption
+2. Write a 1-2 paragraph description
+
+Format:
+CAPTION: [caption here]
+STORY: [description here]"""
 
-        # QWEN 1.8B SPECIFIC FORMATTING - SIMPLE AND EFFECTIVE
+        # QWEN 1.8B SPECIFIC FORMATTING
         if "qwen" in self.llm_model_id.lower():
-            # Qwen works best with this simple format
             formatted_prompt = f"<|im_start|>user\n{prompt}<|im_end|>\n<|im_start|>assistant\n"
-        elif "phi" in self.llm_model_id:  # For Phi models
-            # Phi-2 specific formatting
+        elif "phi" in self.llm_model_id:
            formatted_prompt = f"Instruct: {prompt}\nOutput:"
         elif "gemma" in self.llm_model_id:
-            # Gemma specific formatting
             formatted_prompt = f"<start_of_turn>user\n{prompt}<end_of_turn>\n<start_of_turn>model\n"
         else:
-            # Generic formatting
             formatted_prompt = f"{prompt}\n\n"
 
         # Tokenize and generate
         inputs = self.tokenizer(formatted_prompt, return_tensors="pt").to(self.llm_model.device)
 
         with torch.no_grad():
-            # QWEN OPTIMIZED GENERATION PARAMETERS
             if "qwen" in self.llm_model_id.lower():
                 outputs = self.llm_model.generate(
                     **inputs,
-                    max_new_tokens=300,  # Good length for stories
+                    max_new_tokens=350,  # Increased for caption + story
                     temperature=creativity_level,
                     do_sample=True,
                     top_p=0.9,
                     repetition_penalty=1.1,
                     eos_token_id=self.tokenizer.eos_token_id,
                     pad_token_id=self.tokenizer.eos_token_id,
-                    no_repeat_ngram_size=3  # Prevent repetition
+                    no_repeat_ngram_size=3
                 )
             else:
                 outputs = self.llm_model.generate(
                     **inputs,
-                    max_new_tokens=250,
+                    max_new_tokens=300,
                     temperature=creativity_level,
                     do_sample=True,
                     top_p=0.9,
@@ -290,37 +351,61 @@ class ImageStoryteller:
 
         # Clean up Qwen specific tokens
         if "qwen" in self.llm_model_id.lower():
-            # Remove the prompt and Qwen chat tokens
             story = story.replace(formatted_prompt, "").strip()
             story = story.replace("<|im_end|>", "").strip()
             story = story.replace("<|im_start|>", "").strip()
             story = story.replace("<|endoftext|>", "").strip()
-
-            # Sometimes Qwen repeats, clean that up
-            if "Story:" in story:
-                story = story.split("Story:")[-1].strip()
-            if "Description:" in story:
-                story = story.split("Description:")[-1].strip()
         elif story.startswith(formatted_prompt):
             story = story[len(formatted_prompt):].strip()
 
-        # Additional cleanup for any model
+        # Additional cleanup
         story = story.strip()
 
-        # If story is too short, try a simpler approach
-        if len(story.split()) < 10:
-            # Fallback: use a direct prompt
-            simple_prompt = f"Tell me a story about {objects_str} in {scene_str}."
-            simple_inputs = self.tokenizer(simple_prompt, return_tensors="pt").to(self.llm_model.device)
+        # Ensure proper formatting for caption and story
+        lines = story.split('\n')
+        formatted_lines = []
+        for line in lines:
+            line = line.strip()
+            if line and not line.startswith('CAPTION:') and not line.startswith('STORY:'):
+                # If we have caption/story markers but missing the prefix
+                if 'caption:' in line.lower() and 'caption:' not in line:
+                    line = 'CAPTION: ' + line.split('caption:')[-1].strip()
+                elif 'story:' in line.lower() and 'story:' not in line:
+                    line = 'STORY: ' + line.split('story:')[-1].strip()
+            formatted_lines.append(line)
+
+        story = '\n'.join(formatted_lines)
+
+        # Add visual separator if not already present
+        if 'STORY:' in story:
+            parts = story.split('STORY:', 1)
+            if len(parts) == 2:
+                caption_part = parts[0].replace('CAPTION:', '').strip()
+                story_part = parts[1].strip()
+                # Format with separator
+                story = f"{caption_part}\n{'─' * 40}\n{story_part}"
+
+        # Fallback if generation is too short
+        if len(story.split()) < 15:
+            fallback_prompt = f"Create a caption and story for {objects_str} in {scene_str}."
+            simple_inputs = self.tokenizer(fallback_prompt, return_tensors="pt").to(self.llm_model.device)
             with torch.no_grad():
                 simple_outputs = self.llm_model.generate(
                     **simple_inputs,
-                    max_new_tokens=200,
+                    max_new_tokens=250,
                     temperature=0.8,
                     do_sample=True
                 )
             story = self.tokenizer.decode(simple_outputs[0], skip_special_tokens=True)
-            story = story.replace(simple_prompt, "").strip()
+            story = story.replace(fallback_prompt, "").strip()
+            # Add separator
+            sentences = story.split('. ')
+            if sentences:
+                caption = sentences[0].strip()
+                if not caption.endswith('.'):
+                    caption += '.'
+                rest_of_story = '. '.join(sentences[1:]) if len(sentences) > 1 else story
+                story = f"{caption}\n{'─' * 40}\n{rest_of_story}"
 
         return story
 
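
A note on consuming the new output (a minimal usage sketch, not part of the commit): when the model follows the CAPTION/STORY format, generate_story now returns the caption and the story joined by a 40-character "─" rule, so a caller can split on that rule to recover the two parts. The storyteller construction and the analysis_result shape below are assumptions inferred from the keys the method reads ('objects', plus 'scenes' entries with a 'type' field).

    # Hypothetical usage sketch; ImageStoryteller's constructor is not
    # shown in this diff, so the zero-argument call is an assumption.
    storyteller = ImageStoryteller()

    analysis_result = {
        "objects": ["dog", "frisbee", "person"],  # detected object labels
        "scenes": [{"type": "park"}],             # scene classification result
    }

    output = storyteller.generate_story(analysis_result, creativity_level=0.7)

    # When a CAPTION/STORY pair was produced, the two parts are joined
    # by a 40-character "─" rule; otherwise the output is a single block.
    parts = output.split("─" * 40)
    caption = parts[0].strip()
    story = parts[1].strip() if len(parts) > 1 else ""
    print("Caption:", caption)
    print("Story:", story)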
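
On the prompt formatting: the commit builds the Qwen chat markers (<|im_start|> ... <|im_end|>) by hand. As an aside, recent versions of transformers expose tokenizer.apply_chat_template for checkpoints that ship a chat template; the sketch below shows that route. The Qwen1.5 checkpoint name is an assumption for illustration, since the diff only checks for "qwen" in self.llm_model_id.

    from transformers import AutoTokenizer

    # Checkpoint name is assumed for illustration; any chat-tuned Qwen
    # model that ships a ChatML chat template should behave similarly.
    tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen1.5-1.8B-Chat")

    messages = [{"role": "user", "content": "Describe a dog in a park."}]
    formatted_prompt = tokenizer.apply_chat_template(
        messages,
        tokenize=False,              # return a string rather than token ids
        add_generation_prompt=True,  # append the assistant turn header
    )
    # Yields the "<|im_start|>user ... <|im_end|>\n<|im_start|>assistant\n"
    # framing that the commit concatenates manually.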