"""sentence-transformers integration for luxical embedders.

Wraps a luxical ``Embedder`` as a sentence-transformers ``InputModule`` and
(de)serializes it through a flat, tensor-only state dict so it can be saved
and loaded like any other sentence-transformers module.
"""

import os
from collections import OrderedDict
from typing import Any

import numpy as np
import torch
from safetensors.torch import save_file as save_safetensors
from sentence_transformers.models import InputModule

from luxical.embedder import Embedder, _pack_int_dict, _unpack_int_dict
from luxical.sparse_to_dense_neural_nets import SparseToDenseEmbedder
from luxical.tokenization import ArrowTokenizer


class Transformer(InputModule):
    """sentence-transformers input module backed by a luxical ``Embedder``."""

    config_keys: list[str] = []

    def __init__(self, embedder: Embedder, **kwargs):
        super().__init__()
        self.embedder = embedder

    def tokenize(self, texts: list[str], **kwargs) -> dict[str, torch.Tensor | Any]:
        # Delegate tokenization to the luxical embedder; the tokenized docs are
        # passed through to forward() under the "inputs" key.
        return {"inputs": self.embedder.tokenize(texts)}

    def forward(
        self, features: dict[str, torch.Tensor], **kwargs
    ) -> dict[str, torch.Tensor]:
        # tokens -> bag-of-words -> TF-IDF -> dense embedding
        tokenized_docs = features["inputs"]
        bow = self.embedder.bow_from_tokens(tokenized_docs)
        tfidf = self.embedder.tfidf_from_bow(bow)
        embeddings = self.embedder.bow_to_dense_embedder(tfidf)
        features["sentence_embedding"] = torch.from_numpy(embeddings)
        return features

    def get_sentence_embedding_dimension(self) -> int:
        return self.embedder.embedding_dim

    @classmethod
    def load(
        cls,
        model_name_or_path: str,
        subfolder: str = "",
        token: bool | str | None = None,
        cache_folder: str | None = None,
        revision: str | None = None,
        local_files_only: bool = False,
        **kwargs,
    ):
        # Load the flat tensor state dict written by save(), then rebuild the
        # luxical Embedder from it.
        state_dict = cls.load_torch_weights(
            model_name_or_path,
            subfolder=subfolder,
            token=token,
            cache_folder=cache_folder,
            revision=revision,
            local_files_only=local_files_only,
        )
        embedder = _embedder_from_state_dict(state_dict)
        return cls(embedder=embedder, **kwargs)

    def save(self, output_path: str, *args, safe_serialization: bool = True, **kwargs) -> None:
        state_dict = _embedder_to_state_dict(self.embedder)
        if safe_serialization:
            save_safetensors(state_dict, os.path.join(output_path, "model.safetensors"))
        else:
            torch.save(state_dict, os.path.join(output_path, "pytorch_model.bin"))


def _embedder_from_state_dict(state_dict: OrderedDict[str, torch.Tensor]) -> Embedder:
    """Rebuild an ``Embedder`` from the state dict written by ``_embedder_to_state_dict``."""
    version = int(state_dict["embedder.version"][0].item())
    if version != 1:
        raise NotImplementedError(f"Unsupported embedder version: {version}")

    # Tokenizer: stored as UTF-8 encoded JSON bytes.
    tok_bytes = bytes(
        state_dict["embedder.tokenizer"].cpu().numpy().astype(np.uint8).tolist()
    )
    tokenizer = ArrowTokenizer(tok_bytes.decode("utf-8"))

    # Recognized n-grams.
    recognized_ngrams = (
        state_dict["embedder.recognized_ngrams"]
        .cpu()
        .numpy()
        .astype(np.int64, copy=False)
    )

    # n-gram hash -> index map, stored as parallel key/value arrays.
    keys = state_dict["embedder.ngram_keys"].cpu().numpy().astype(np.int64, copy=False)
    vals = state_dict["embedder.ngram_vals"].cpu().numpy().astype(np.int64, copy=False)
    ngram_map = _pack_int_dict(keys, vals)

    # IDF weights.
    idf_values = (
        state_dict["embedder.idf_values"].cpu().numpy().astype(np.float32, copy=False)
    )

    # Layers of the sparse-to-dense projection network.
    num_layers = int(state_dict["embedder.num_layers"][0].item())
    layers = [
        state_dict[f"embedder.nn_layer_{i}"]
        .cpu()
        .numpy()
        .astype(np.float32, copy=False)
        for i in range(num_layers)
    ]
    s2d = SparseToDenseEmbedder(layers=layers)

    return Embedder(
        tokenizer=tokenizer,
        recognized_ngrams=recognized_ngrams,
        ngram_hash_to_ngram_idx=ngram_map,
        idf_values=idf_values,
        bow_to_dense_embedder=s2d,
    )


def _embedder_to_state_dict(embedder: Embedder) -> OrderedDict[str, torch.Tensor]:
    """Flatten an ``Embedder`` into a tensor-only state dict suitable for safetensors."""
    sd: OrderedDict[str, torch.Tensor] = OrderedDict()

    # Version
    sd["embedder.version"] = torch.tensor([1], dtype=torch.long)

    # Tokenizer JSON bytes
    tok_bytes = np.frombuffer(
        embedder.tokenizer.to_str().encode("utf-8"), dtype=np.uint8
    )
    sd["embedder.tokenizer"] = torch.from_numpy(tok_bytes.copy())

    # Recognized n-grams
    sd["embedder.recognized_ngrams"] = torch.from_numpy(
        embedder.recognized_ngrams.astype(np.int64, copy=False)
    )

    # Hash map keys/values
    keys, vals = _unpack_int_dict(embedder.ngram_hash_to_ngram_idx)
    sd["embedder.ngram_keys"] = torch.from_numpy(keys.astype(np.int64, copy=False))
    sd["embedder.ngram_vals"] = torch.from_numpy(vals.astype(np.int64, copy=False))

    # IDF
    sd["embedder.idf_values"] = torch.from_numpy(
        embedder.idf_values.astype(np.float32, copy=False)
    )

    # Layers
    layers = embedder.bow_to_dense_embedder.layers
    sd["embedder.num_layers"] = torch.tensor([len(layers)], dtype=torch.long)
    for i, layer in enumerate(layers):
        sd[f"embedder.nn_layer_{i}"] = torch.from_numpy(
            layer.astype(np.float32, copy=False)
        )

    return sd