Spaces:
Runtime error
Runtime error
| from fastapi import FastAPI, HTTPException | |
| from typing import List | |
| from pydantic import BaseModel | |
| from transformers import AutoModelForSeq2SeqLM, AutoTokenizer | |
| from IndicTransToolkit import IndicProcessor | |
| from fastapi.middleware.cors import CORSMiddleware | |
| import torch | |
# ASGI application object for the translation service.
app = FastAPI()

# Install CORS handling with every origin, method and header allowed.
# NOTE(review): wildcard origins are combined with allow_credentials=True;
# FastAPI's CORS documentation recommends listing origins explicitly when
# credentials are allowed -- confirm this is intended.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Hugging Face checkpoint shared by the model and its tokenizer.
_MODEL_ID = "ai4bharat/indictrans2-indic-indic-1B"

# Load the seq2seq model and tokenizer once at import time;
# trust_remote_code=True lets the checkpoint supply its own Python classes.
model = AutoModelForSeq2SeqLM.from_pretrained(_MODEL_ID, trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained(_MODEL_ID, trust_remote_code=True)

# Text pre-processor from IndicTransToolkit, configured for inference.
ip = IndicProcessor(inference=True)

# Use the GPU when one is available, otherwise stay on the CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
model = model.to(DEVICE)
def translate_text(sentences: List[str], src_lang: str, target_lang: str):
    """Translate a batch of sentences with the module-level model.

    The batch is pre-processed by ``ip``, tokenized, run through
    ``model.generate`` with beam search, decoded, and finally post-processed
    by ``ip``.

    Args:
        sentences: Sentences to translate.
        src_lang: Source language tag forwarded to ``ip.preprocess_batch``.
        target_lang: Target language tag forwarded to ``ip.preprocess_batch``
            and ``ip.postprocess_batch``.

    Returns:
        The post-processed translations, or an empty list when ``sentences``
        is empty.

    Raises:
        Exception: Any error raised by the processor, tokenizer or model is
            propagated so the caller can report it as a failure instead of
            receiving the error text as if it were a translation.
    """
    # Nothing to translate: skip the tokenizer and model entirely.
    if not sentences:
        return []

    batch = ip.preprocess_batch(sentences, src_lang=src_lang, tgt_lang=target_lang)

    inputs = tokenizer(
        batch,
        truncation=True,
        padding="longest",
        return_tensors="pt",
        return_attention_mask=True,
    ).to(DEVICE)

    # Inference only: no gradients are needed.
    with torch.no_grad():
        generated_tokens = model.generate(
            **inputs,
            use_cache=True,
            min_length=0,
            max_length=256,
            num_beams=5,
            num_return_sequences=1,
        )

    # Decode under the tokenizer's target-side context.
    with tokenizer.as_target_tokenizer():
        decoded = tokenizer.batch_decode(
            generated_tokens.detach().cpu().tolist(),
            skip_special_tokens=True,
        )

    # Counterpart of preprocess_batch; the toolkit's documented pipeline runs
    # this on the decoded text before returning it to the user.
    return ip.postprocess_batch(decoded, lang=target_lang)
# NOTE(review): no route decorator is visible on this handler in this view;
# confirm it is registered on `app`.
def read_root():
    """Return a fixed greeting payload."""
    greeting = {"Hello": "World"}
    return greeting
class TranslateRequest(BaseModel):
    """Request body accepted by ``translate``."""

    # Sentences to translate, handed to translate_text as one batch.
    sentences: List[str]
    # Source language tag, forwarded unchanged to translate_text.
    src_lang: str
    # Target language tag, forwarded unchanged to translate_text.
    target_lang: str
# NOTE(review): no route decorator is visible on this handler in this view;
# confirm it is registered on `app`.
def translate(request: TranslateRequest):
    """Translate the sentences in ``request`` and return the translations.

    Args:
        request: Validated body carrying ``sentences``, ``src_lang`` and
            ``target_lang``.

    Returns:
        Whatever ``translate_text`` returns for the batch.

    Raises:
        HTTPException: With status 500 when ``translate_text`` raises, or
            when it hands back a bare string instead of a list.
    """
    try:
        result = translate_text(
            request.sentences, request.src_lang, request.target_lang
        )
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e

    # Defensive: a bare str (rather than a list of translations) is treated
    # as an error message and must not be sent back as a successful response.
    # The check sits outside the try so this HTTPException is not re-wrapped.
    if isinstance(result, str):
        raise HTTPException(status_code=500, detail=result)

    return result