Spaces:
Runtime error
Runtime error
| from transformers import MT5ForConditionalGeneration, MT5Tokenizer | |
| from transformers import AutoTokenizer | |
| import re | |
class PersianTextProcessor:
    """
    Clean Persian text and translate it to English with an MT5 model.

    Attributes:
        model_size (str): Size tag interpolated into the model id.
        model_name (str): Hugging Face model id the tokenizer and model are loaded from.
        tokenizer (MT5Tokenizer): Tokenizer loaded from ``model_name``.
        model (MT5ForConditionalGeneration): Model loaded from ``model_name``.

    Methods:
        clean_persian_text(text): Strip selected emojis/symbols and '@'/'#' markers.
        run_model(input_string, **generator_args): Encode, generate and decode.
        translate_text(persian_text, **generator_args): Clean, then translate.
    """

    # Compiled once when the class is created rather than on every
    # clean_persian_text call.
    _EMOJI_PATTERN = re.compile(
        "["
        "\U0001F600-\U0001F64F"  # emoticons
        "\U0001F300-\U0001F5FF"  # symbols & pictographs
        "\U0001F680-\U0001F6FF"  # transport & map symbols
        "\U0001F1E0-\U0001F1FF"  # flags (iOS)
        "]+",
        flags=re.UNICODE,
    )
    # White heart (U+1F90D), heavy black heart (U+2764), love-you gesture
    # (U+1F91F) and victory hand (U+270C), together with any trailing
    # variation selector (U+FE0F) / zero-width joiner (U+200D). Consuming the
    # trailing marks avoids leaving invisible residue such as a lone U+FE0F.
    _SYMBOL_PATTERN = re.compile(
        "[\U0001F90D\u2764\U0001F91F\u270C][\uFE0F\u200D]*"
    )

    def __init__(self, model_size="small"):
        """
        Load the tokenizer and model for the requested size.

        Args:
            model_size (str): Size tag used to build the model id
                ``persiannlp/mt5-{model_size}-parsinlu-opus-translation_fa_en``.
        """
        self.model_size = model_size
        self.model_name = (
            f"persiannlp/mt5-{self.model_size}-parsinlu-opus-translation_fa_en"
        )
        self.tokenizer = MT5Tokenizer.from_pretrained(self.model_name)
        self.model = MT5ForConditionalGeneration.from_pretrained(self.model_name)

    def clean_persian_text(self, text: str) -> str:
        """
        Remove emojis and selected symbols, and rewrite '@' / '#' markers.

        Args:
            text (str): The input Persian text.

        Returns:
            str: The text with emojis in the ranges of ``_EMOJI_PATTERN`` and
            the symbols of ``_SYMBOL_PATTERN`` removed, '@' deleted and every
            '#' replaced by 'hashtag_'.
        """
        text = self._EMOJI_PATTERN.sub("", text)
        text = self._SYMBOL_PATTERN.sub("", text)
        return text.replace("@", "").replace("#", "hashtag_")

    def run_model(self, input_string, **generator_args) -> list:
        """
        Encode ``input_string``, run ``model.generate`` and decode the output.

        Args:
            input_string (str): Text to feed to the model.
            **generator_args: Keyword arguments forwarded unchanged to
                ``self.model.generate`` (e.g. ``max_length``, ``num_beams``).

        Returns:
            list[str]: Decoded generations with special tokens stripped.
            ``batch_decode`` yields one string per generated sequence, so this
            is a list even for a single input string.
        """
        input_ids = self.tokenizer.encode(input_string, return_tensors="pt")
        generated = self.model.generate(input_ids, **generator_args)
        return self.tokenizer.batch_decode(generated, skip_special_tokens=True)

    def translate_text(self, persian_text, **generator_args) -> list:
        """
        Clean Persian text and translate it to English.

        Args:
            persian_text (str): The Persian text to translate.
            **generator_args: Optional keyword arguments forwarded to
                ``run_model`` (and from there to ``model.generate``). When
                omitted, the model's default generation settings apply, which
                may limit output length; pass e.g. ``max_length`` for long text.

        Returns:
            list[str]: The decoded translation(s), as returned by ``run_model``.
        """
        text_cleaned = self.clean_persian_text(persian_text)
        return self.run_model(input_string=text_cleaned, **generator_args)