# luxical-one / luxical_st_wrapper.py
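"""Sentence Transformers wrapper for the luxical embedder.

Defines a custom ``InputModule`` that tokenizes text with the embedder's own
tokenizer, produces dense sentence embeddings via its BOW -> TF-IDF -> dense
pipeline, and round-trips the embedder through a flat tensor state dict for
saving and loading.
"""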
import os
from collections import OrderedDict
from typing import Any

import numpy as np
import torch
from safetensors.torch import save_file as save_safetensors
from sentence_transformers.models import InputModule

from luxical.embedder import Embedder, _pack_int_dict, _unpack_int_dict
from luxical.sparse_to_dense_neural_nets import SparseToDenseEmbedder
from luxical.tokenization import ArrowTokenizer


class Transformer(InputModule):
    """Custom Sentence Transformers module backed by a luxical ``Embedder``."""

    config_keys: list[str] = []

    def __init__(self, embedder: Embedder, **kwargs):
        super().__init__()
        self.embedder = embedder

    def tokenize(self, texts: list[str], **kwargs) -> dict[str, torch.Tensor | Any]:
        # Delegate tokenization to the wrapped embedder; forward() reads it back from "inputs".
        return {"inputs": self.embedder.tokenize(texts)}

    def forward(
        self, features: dict[str, torch.Tensor], **kwargs
    ) -> dict[str, torch.Tensor]:
        tokenized_docs = features["inputs"]
        # Bag-of-words counts -> TF-IDF weighting -> dense projection.
        bow = self.embedder.bow_from_tokens(tokenized_docs)
        tfidf = self.embedder.tfidf_from_bow(bow)
        embeddings = self.embedder.bow_to_dense_embedder(tfidf)
        # The projection yields a NumPy array; convert it for Sentence Transformers.
        embeddings = torch.from_numpy(embeddings)
        features["sentence_embedding"] = embeddings
        return features

    def get_sentence_embedding_dimension(self) -> int:
        return self.embedder.embedding_dim

    @classmethod
    def load(
        cls,
        model_name_or_path: str,
        subfolder: str = "",
        token: bool | str | None = None,
        cache_folder: str | None = None,
        revision: str | None = None,
        local_files_only: bool = False,
        **kwargs,
    ):
        # Fetch the serialized weights (locally or from the Hub) and rebuild the embedder.
        state_dict = cls.load_torch_weights(
            model_name_or_path,
            subfolder=subfolder,
            token=token,
            cache_folder=cache_folder,
            revision=revision,
            local_files_only=local_files_only,
        )
        embedder = _embedder_from_state_dict(state_dict)
        return cls(embedder=embedder, **kwargs)

    def save(self, output_path, *args, safe_serialization=True, **kwargs) -> None:
        # Flatten the embedder into a tensor-only state dict, then write it in
        # either safetensors or legacy PyTorch format.
        state_dict = _embedder_to_state_dict(self.embedder)
        if safe_serialization:
            save_safetensors(state_dict, os.path.join(output_path, "model.safetensors"))
        else:
            torch.save(state_dict, os.path.join(output_path, "pytorch_model.bin"))
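
# The embedder round-trips through a flat tensor state dict whose keys are
# written by ``_embedder_to_state_dict`` below:
#   embedder.version                             - format version (currently 1)
#   embedder.tokenizer                           - UTF-8 bytes of the tokenizer JSON
#   embedder.recognized_ngrams                   - int64 array of recognized n-grams
#   embedder.ngram_keys / embedder.ngram_vals    - packed hash-map entries
#   embedder.idf_values                          - float32 IDF weights
#   embedder.num_layers / embedder.nn_layer_{i}  - dense projection layers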


def _embedder_from_state_dict(state_dict: OrderedDict[str, torch.Tensor]) -> Embedder:
    """Reconstruct a luxical ``Embedder`` from the flat tensor state dict."""
    version = int(state_dict["embedder.version"][0].item())
    if version != 1:
        raise NotImplementedError(f"Unsupported embedder version: {version}")
    # The tokenizer is stored as the raw UTF-8 bytes of its JSON definition.
    tok_bytes = bytes(
        state_dict["embedder.tokenizer"].cpu().numpy().astype(np.uint8).tolist()
    )
    tokenizer = ArrowTokenizer(tok_bytes.decode("utf-8"))
    recognized_ngrams = (
        state_dict["embedder.recognized_ngrams"]
        .cpu()
        .numpy()
        .astype(np.int64, copy=False)
    )
    # Rebuild the n-gram hash -> index map from its packed key/value arrays.
    keys = state_dict["embedder.ngram_keys"].cpu().numpy().astype(np.int64, copy=False)
    vals = state_dict["embedder.ngram_vals"].cpu().numpy().astype(np.int64, copy=False)
    ngram_map = _pack_int_dict(keys, vals)
    idf_values = (
        state_dict["embedder.idf_values"].cpu().numpy().astype(np.float32, copy=False)
    )
    num_layers = int(state_dict["embedder.num_layers"][0].item())
    layers = [
        state_dict[f"embedder.nn_layer_{i}"]
        .cpu()
        .numpy()
        .astype(np.float32, copy=False)
        for i in range(num_layers)
    ]
    s2d = SparseToDenseEmbedder(layers=layers)
    return Embedder(
        tokenizer=tokenizer,
        recognized_ngrams=recognized_ngrams,
        ngram_hash_to_ngram_idx=ngram_map,
        idf_values=idf_values,
        bow_to_dense_embedder=s2d,
    )


def _embedder_to_state_dict(embedder: Embedder) -> OrderedDict[str, torch.Tensor]:
    """Flatten a luxical ``Embedder`` into a tensor-only state dict for serialization."""
    sd: "OrderedDict[str, torch.Tensor]" = OrderedDict()
    # Version
    sd["embedder.version"] = torch.tensor([1], dtype=torch.long)
    # Tokenizer JSON bytes
    tok_bytes = np.frombuffer(
        embedder.tokenizer.to_str().encode("utf-8"), dtype=np.uint8
    )
    sd["embedder.tokenizer"] = torch.from_numpy(tok_bytes.copy())
    # Recognized n-grams
    sd["embedder.recognized_ngrams"] = torch.from_numpy(
        embedder.recognized_ngrams.astype(np.int64, copy=False)
    )
    # Hash map keys/values
    keys, vals = _unpack_int_dict(embedder.ngram_hash_to_ngram_idx)
    sd["embedder.ngram_keys"] = torch.from_numpy(keys.astype(np.int64, copy=False))
    sd["embedder.ngram_vals"] = torch.from_numpy(vals.astype(np.int64, copy=False))
    # IDF weights
    sd["embedder.idf_values"] = torch.from_numpy(
        embedder.idf_values.astype(np.float32, copy=False)
    )
    # Dense projection layers
    layers = embedder.bow_to_dense_embedder.layers
    sd["embedder.num_layers"] = torch.tensor([len(layers)], dtype=torch.long)
    for i, layer in enumerate(layers):
        sd[f"embedder.nn_layer_{i}"] = torch.from_numpy(
            layer.astype(np.float32, copy=False)
        )
    return sd
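

# --- Usage sketch (illustrative, not part of the wrapper) ---
# A minimal example of wiring this module into Sentence Transformers by hand,
# assuming the serialized weights are available locally or on the Hub; the
# repo id below is an assumption based on this file's location.
#
#     from sentence_transformers import SentenceTransformer
#
#     module = Transformer.load("lukemerrick/luxical-one")
#     model = SentenceTransformer(modules=[module])
#     embeddings = model.encode(["a first sentence", "a second sentence"])
#     print(embeddings.shape)  # (2, module.get_sentence_embedding_dimension())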