# canary-qwen-2.5b / handler.py
# (Hugging Face Inference Endpoints custom handler; last commit 28e37b7,
#  "Update handler.py" by manueljohnson063)
import base64
import io

import torch
import soundfile as sf

from nemo.collections.speechlm2.models import SALM
class EndpointHandler():
    """Inference Endpoints handler wrapping a NeMo SALM (ASR + LLM) checkpoint.

    Request schema (``data`` dict):
      * ``mode="asr"``: transcribe ``data["audio"]`` — either a list of float
        samples, raw audio-file bytes, or a base64-encoded string of such bytes.
      * ``mode="llm"``: generate text from ``data["inputs"]`` (a prompt string).

    Returns ``{"text": ...}`` on success, ``{"error": ...}`` on a bad mode.
    """

    def __init__(self, path=""):
        # Load the combined ASR + LLM model from the Hugging Face repo directory.
        self.model = SALM.from_pretrained(path)
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model.to(self.device)
        # Inference only: disable dropout / switch norm layers to eval behavior.
        self.model.eval()

    def _audio_to_tensor(self, audio):
        """Normalize the `audio` payload into a (1, num_samples) float32 tensor."""
        if isinstance(audio, list):
            wav = torch.tensor(audio, dtype=torch.float32)
        else:
            if isinstance(audio, str):
                # JSON payloads carry binary audio as a base64 string; decode it.
                audio = base64.b64decode(audio)
            # BUGFIX: sf.read() cannot take raw bytes — it would interpret them
            # as a file path. Wrap the bytes in a file-like BytesIO object.
            data, _sample_rate = sf.read(io.BytesIO(audio))
            # soundfile returns float64; the model expects float32.
            wav = torch.tensor(data, dtype=torch.float32)
        return wav.unsqueeze(0).to(self.device)

    def __call__(self, data):
        """Dispatch on data["mode"]: 'asr' or 'llm'."""
        mode = data.get("mode")
        if mode == "asr":
            audio_tensor = self._audio_to_tensor(data.get("audio"))
            # NOTE(review): assumes SALM exposes `.transcribe` accepting a
            # (1, num_samples) tensor — confirm against the nemo speechlm2 API.
            with torch.inference_mode():
                transcripts = self.model.transcribe(audio_tensor)
            return {"text": transcripts}
        elif mode == "llm":
            prompt = data.get("inputs", "")
            with torch.inference_mode():
                outputs = self.model.generate([prompt])
            return {"text": outputs[0]}
        else:
            return {"error": "Please specify mode as 'asr' or 'llm'."}