| | import torch |
| | import soundfile as sf |
| | from nemo.collections.speechlm2.models import SALM |
| |
|
class EndpointHandler():
    """Inference-endpoint wrapper around a pretrained NeMo SALM checkpoint.

    Dispatches request dicts to one of two modes:
      * ``"asr"`` — transcribe audio supplied either as a raw sample list or
        as anything ``soundfile.read`` accepts (path / file-like object).
      * ``"llm"`` — plain text generation from a prompt string.
    """

    def __init__(self, path=""):
        # Load the pretrained speech-language model and move it to the best
        # available device (GPU when present, CPU otherwise).
        self.model = SALM.from_pretrained(path)
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model.to(self.device)

    def __call__(self, data):
        """Handle one request payload.

        Args:
            data: dict with a ``"mode"`` key ("asr" or "llm"). For "asr" it
                must also carry ``"audio"`` (list of float samples, or a
                path/file-like readable by soundfile). For "llm" it may carry
                ``"inputs"`` (prompt string, defaults to "").

        Returns:
            ``{"text": ...}`` on success, ``{"error": ...}`` on bad input.
        """
        mode = data.get("mode")

        if mode == "asr":
            audio = data.get("audio")
            if audio is None:
                # Fail fast with a clear message instead of crashing inside
                # sf.read / torch.tensor on a None payload.
                return {"error": "ASR mode requires an 'audio' field."}

            if isinstance(audio, list):
                samples = audio
            else:
                # soundfile returns float64 samples; the sample rate is
                # discarded here — assumes the caller already matches the
                # model's expected rate (TODO confirm / resample if needed).
                samples, _sr = sf.read(audio)

            # Cast explicitly to float32: both sf.read output and Python
            # float lists would otherwise become float64 tensors, which
            # typically mismatches the model's weights.
            audio_tensor = (
                torch.as_tensor(samples, dtype=torch.float32)
                .unsqueeze(0)  # add batch dimension
                .to(self.device)
            )

            transcripts = self.model.transcribe(audio_tensor)
            return {"text": transcripts}

        if mode == "llm":
            prompt = data.get("inputs", "")
            outputs = self.model.generate([prompt])
            return {"text": outputs[0]}

        return {"error": "Please specify mode as 'asr' or 'llm'."}
| |
|