lkhl commited on
Commit
ba32678
Β·
verified Β·
1 Parent(s): b58be07

Update rynnec/__init__.py

Browse files
Files changed (1) hide show
  1. rynnec/__init__.py +2 -2
rynnec/__init__.py CHANGED
@@ -145,7 +145,7 @@ def mm_infer(images_or_videos, vlprocessor, instruct, model, tokenizer, modal='v
145
 
146
  return outputs
147
 
148
- def mm_infer_segmentation(images_or_videos, vlprocessor, instruct, model, tokenizer, modal='video', seg_start_idx=0, **kwargs):
149
 
150
  image2maskids = kwargs.get('image2maskids', [])
151
  img_size=1024
@@ -264,6 +264,6 @@ def mm_infer_segmentation(images_or_videos, vlprocessor, instruct, model, tokeni
264
  )
265
 
266
  outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0].strip()
267
- pred_masks_sigmoid = pred_masks.sigmoid()>0.5
268
 
269
  return outputs, pred_masks_sigmoid
 
145
 
146
  return outputs
147
 
148
+ def mm_infer_segmentation(images_or_videos, vlprocessor, instruct, model, tokenizer, modal='video', seg_start_idx=0, mask_threshold=0.5, **kwargs):
149
 
150
  image2maskids = kwargs.get('image2maskids', [])
151
  img_size=1024
 
264
  )
265
 
266
  outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0].strip()
267
+ pred_masks_sigmoid = pred_masks.sigmoid() > mask_threshold
268
 
269
  return outputs, pred_masks_sigmoid