| | """Tokenization classes for ProteinGLM.""" |
| |
|
| | import os |
| | from typing import List, Optional, Union, Dict, Any |
| | from torch import TensorType |
| | from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast |
| | from transformers.tokenization_utils_base import EncodedInput, BatchEncoding |
| |
|
| | VOCAB_FILES_NAMES = {"vocab_file": "tokenizer.model"} |
| |
|
| |
|
| | def load_vocab_file(vocab_file: str) -> List[str]: |
| | with open(vocab_file, "r") as f: |
| | lines = f.read().splitlines() |
| | return [line.strip() for line in lines] |
| |
|
| |
|
| | class ProteinGLMTokenizer(PreTrainedTokenizer): |
| | """ |
| | Constructs a ProteinGLM tokenizer. |
| | """ |
| |
|
| | vocab_files_names = VOCAB_FILES_NAMES |
| | model_input_names = ["input_ids", "attention_mask", "position_ids"] |
| | def __init__( |
| | self, |
| | vocab_file: str, |
| | unk_token: str = "<unk>", |
| | pad_token: str = "<pad>", |
| | mask_token: str = "<mask>", |
| | eos_token: str = "<eos>", |
| | model_max_length: int = 2048, |
| | additional_special_tokens: Optional[List[str]] = None, |
| | **kwargs, |
| | ): |
| | self.all_tokens = load_vocab_file(vocab_file) |
| | self._id_to_token = dict(enumerate(self.all_tokens)) |
| | self._token_to_id = {tok: ind for ind, tok in enumerate(self.all_tokens)} |
| |
|
| | if additional_special_tokens is None: |
| | additional_special_tokens = ['<pad>', '<mask>', '<gmask>', '<smask>', '<eod>', '<sop>', '<eop>', '<eos>', '<unk>'] |
| |
|
| | super().__init__( |
| | unk_token=unk_token, |
| | pad_token=pad_token, |
| | mask_token=mask_token, |
| | eos_token=eos_token, |
| | model_max_length=model_max_length, |
| | additional_special_tokens=additional_special_tokens, |
| | **kwargs, |
| | ) |
| |
|
| | self.unique_no_split_tokens = self.all_tokens |
| | self._update_trie(self.unique_no_split_tokens) |
| |
|
| | def _convert_id_to_token(self, index: int) -> str: |
| | return self._id_to_token.get(index, self.unk_token) |
| |
|
| | def _convert_token_to_id(self, token: str) -> int: |
| | return self._token_to_id.get(token, self._token_to_id.get(self.unk_token)) |
| |
|
| | def _tokenize(self, text: str, **kwargs) -> List[str]: |
| | return text.split() |
| |
|
| | def get_vocab(self) -> dict: |
| | base_vocab = self._token_to_id.copy() |
| | base_vocab.update(self.added_tokens_encoder) |
| | return base_vocab |
| |
|
| | def token_to_id(self, token: str) -> int: |
| | return self._token_to_id.get(token, self._token_to_id.get(self.unk_token)) |
| |
|
| | def id_to_token(self, index: int) -> str: |
| | return self._id_to_token.get(index, self.unk_token) |
| |
|
| | def build_inputs_with_special_tokens( |
| | self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None |
| | ) -> List[int]: |
| | sep = [self.eos_token_id] |
| | if token_ids_1 is None: |
| | if self.eos_token_id is None: |
| | return token_ids_0 |
| | else: |
| | return token_ids_0 + sep |
| | elif self.eos_token_id is None: |
| | raise ValueError("Cannot tokenize multiple sequences when EOS token is not set!") |
| | return token_ids_0 + sep + token_ids_1 + sep |
| |
|
| |
|
| | def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> tuple: |
| | vocab_file = os.path.join(save_directory, (filename_prefix + "-" if filename_prefix else "") + "tokenizer.model") |
| | with open(vocab_file, "w") as f: |
| | f.write("\n".join(self.all_tokens)) |
| | return (vocab_file,) |
| |
|
| | @property |
| | def vocab_size(self) -> int: |
| | return len(self.all_tokens) |
| |
|
| | def apply_chat_template( |
| | self, |
| | query, |
| | add_generation_prompt: bool = True, |
| | tokenize: bool = True, |
| | padding: bool = False, |
| | truncation: bool = False, |
| | max_length: Optional[int] = None, |
| | return_tensors: Optional[Union[str, TensorType]] = None, |
| | return_dict: bool = False, |
| | tokenizer_kwargs: Optional[Dict[str, Any]] = None, |
| | add_special_tokens: bool = True, |
| | **kwargs, |
| | ) -> Union[str, List[int], List[str], List[List[int]], BatchEncoding]: |
| |
|
| | generation_prompt = "<gmask><sop><eos>" |
| | if isinstance(query, str): |
| | query = [query] |
| | prompt_query = [] |
| | if add_generation_prompt: |
| | for each in query: |
| | assert isinstance(each, str) |
| | prompt_query.append(generation_prompt+each) |
| | else: |
| | prompt_query = query |
| | if tokenize: |
| | output = self.batch_encode_plus( |
| | prompt_query, |
| | padding=padding, |
| | truncation=truncation, |
| | max_length=max_length, |
| | return_tensors=return_tensors, |
| | is_split_into_words=True, |
| | add_special_tokens=False |
| | ) |
| | if return_dict: |
| | return output |
| | else: |
| | return output["input_ids"] |
| | else: |
| | return prompt_query |