import json
import os
from io import BytesIO
from pathlib import Path
from typing import Any, Dict, List, Literal, Optional, Union

import requests
import torch
from PIL import Image
from torch import nn
from transformers import AutoConfig, AutoModel, AutoProcessor


class Transformer(nn.Module):
    """sentence-transformers custom module for jina-embeddings-v4.

    Accepts mixed batches of texts, image URLs, image file paths, and PIL
    images, and returns L2-normalized single-vector embeddings.
    """

    save_in_root: bool = True

    def __init__(
        self,
        model_name_or_path: str = "jinaai/jina-embeddings-v4",
        max_seq_length: Optional[int] = None,
        config_args: Optional[Dict[str, Any]] = None,
        model_args: Optional[Dict[str, Any]] = None,
        tokenizer_args: Optional[Dict[str, Any]] = None,
        cache_dir: Optional[str] = None,
        backend: Literal["torch", "onnx", "openvino"] = "torch",
        **kwargs,
    ) -> None:
        super().__init__()
        if backend != "torch":
            raise ValueError(
                f"Backend '{backend}' is not supported, please use 'torch' instead"
            )
        config_kwargs = config_args or {}
        model_kwargs = model_args or {}
        tokenizer_kwargs = tokenizer_args or {}

        self.config = AutoConfig.from_pretrained(
            model_name_or_path, cache_dir=cache_dir, **config_kwargs
        )
        # Pop from the normalized dict so that model_args=None does not crash,
        # and so default_task is not forwarded to from_pretrained below.
        self.default_task = model_kwargs.pop("default_task", None)
        if self.default_task and self.default_task not in self.config.task_names:
            raise ValueError(
                f"Invalid task: {self.default_task}. Must be one of {self.config.task_names}."
            )

        self.model = AutoModel.from_pretrained(
            model_name_or_path, config=self.config, cache_dir=cache_dir, **model_kwargs
        )
        self.processor = AutoProcessor.from_pretrained(
            model_name_or_path,
            cache_dir=cache_dir,
            use_fast=True,
            **tokenizer_kwargs,
        )
        self.max_seq_length = max_seq_length or 8192

    def tokenize(
        self, texts: List[Union[str, Image.Image]], padding: Union[str, bool] = True
    ) -> Dict[str, torch.Tensor]:
        encoding = {}
        text_indices = []
        image_indices = []
        for i, text in enumerate(texts):
            if isinstance(text, str):
                # Strip retrieval prefixes before checking whether the string
                # is actually an image URL or a local file path.
                clean_text = text
                if text.startswith("Query: "):
                    clean_text = text[len("Query: ") :]
                elif text.startswith("Passage: "):
                    clean_text = text[len("Passage: ") :]

                if clean_text.startswith("http"):
                    response = requests.get(clean_text)
                    texts[i] = Image.open(BytesIO(response.content)).convert("RGB")
                    image_indices.append(i)
                else:
                    try:
                        if Path(clean_text).is_file():
                            texts[i] = Image.open(clean_text).convert("RGB")
                            image_indices.append(i)
                        else:
                            text_indices.append(i)
                    except Exception:
                        # Strings that are not valid paths (e.g. too long to
                        # be a filename) are treated as plain text.
                        text_indices.append(i)
            elif isinstance(text, Image.Image):
                image_indices.append(i)
            else:
                raise ValueError(f"Invalid input type: {type(text)}")
        if text_indices:
            _texts = [texts[i] for i in text_indices]
            text_features = self.processor.process_texts(
                _texts, max_length=self.max_seq_length
            )
            for key, value in text_features.items():
                encoding[f"text_{key}"] = value
            encoding["text_indices"] = text_indices

        if image_indices:
            _images = [texts[i] for i in image_indices]
            img_features = self.processor.process_images(_images)
            for key, value in img_features.items():
                encoding[f"image_{key}"] = value
            encoding["image_indices"] = image_indices

        return encoding

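    # Routing example (illustrative): for texts == ["Query: hello",
    # "photo.jpg", <PIL.Image>], and assuming "photo.jpg" exists on disk,
    # tokenize() yields text_indices == [0] and image_indices == [1, 2];
    # forward() below uses these indices to restore the input order.
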
    def forward(
        self,
        features: Dict[str, torch.Tensor],
        task: Optional[str] = None,
        truncate_dim: Optional[int] = None,
    ) -> Dict[str, torch.Tensor]:
        self.model.eval()

        if task is None:
            if self.default_task is None:
                raise ValueError(
                    "Task must be specified before encoding data. You can set it either during "
                    "loading the model (e.g., model_kwargs={'default_task': 'retrieval'}) or "
                    "pass it as an argument to the encode method (e.g., model.encode(texts, task='retrieval'))."
                )
            task = self.default_task
        else:
            if task not in self.config.task_names:
                raise ValueError(
                    f"Invalid task: {task}. Must be one of {self.config.task_names}."
                )

        device = self.model.device.type
        all_embeddings = []

        with torch.no_grad():
            # Texts and images are encoded as separate sub-batches; the saved
            # indices are used afterwards to restore the original order.
            if any(k.startswith("text_") for k in features.keys()):
                text_batch = {
                    k[len("text_") :]: v.to(device)
                    for k, v in features.items()
                    if k.startswith("text_") and k != "text_indices"
                }
                text_indices = features.get("text_indices", [])
                with torch.autocast(device_type=device, dtype=torch.bfloat16):
                    text_embeddings = self.model(
                        **text_batch, task_label=task
                    ).single_vec_emb
                    # Truncate first, then re-normalize, so that shortened
                    # embeddings keep unit norm.
                    if truncate_dim:
                        text_embeddings = text_embeddings[:, :truncate_dim]
                    text_embeddings = torch.nn.functional.normalize(
                        text_embeddings, p=2, dim=-1
                    )
                for i, embedding in enumerate(text_embeddings):
                    all_embeddings.append((text_indices[i], embedding))

            if any(k.startswith("image_") for k in features.keys()):
                image_batch = {
                    k[len("image_") :]: v.to(device)
                    for k, v in features.items()
                    if k.startswith("image_") and k != "image_indices"
                }
                image_indices = features.get("image_indices", [])

                with torch.autocast(device_type=device, dtype=torch.bfloat16):
                    img_embeddings = self.model(
                        **image_batch, task_label=task
                    ).single_vec_emb
                    if truncate_dim:
                        img_embeddings = img_embeddings[:, :truncate_dim]
                    img_embeddings = torch.nn.functional.normalize(
                        img_embeddings, p=2, dim=-1
                    )

                for i, embedding in enumerate(img_embeddings):
                    all_embeddings.append((image_indices[i], embedding))

        if not all_embeddings:
            raise RuntimeError("No embeddings were generated")

        # Sort by original input index before stacking.
        all_embeddings.sort(key=lambda x: x[0])
        combined_embeddings = torch.stack([emb for _, emb in all_embeddings])
        features["sentence_embedding"] = combined_embeddings

        return features

    @classmethod
    def load(cls, input_path: str) -> "Transformer":
        return cls(model_name_or_path=input_path)
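

if __name__ == "__main__":
    # Usage sketch (illustrative only). sentence-transformers loads this file
    # as the checkpoint's custom first module when trust_remote_code is
    # enabled; "retrieval" is one of the task names referenced in the error
    # messages above, and the query string is a placeholder.
    from sentence_transformers import SentenceTransformer

    model = SentenceTransformer(
        "jinaai/jina-embeddings-v4",
        trust_remote_code=True,
        model_kwargs={"default_task": "retrieval"},
    )
    # Texts, image URLs, file paths, and PIL images can be mixed in one batch.
    embeddings = model.encode(["Query: what is multimodal retrieval?"])
    print(embeddings.shape)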