import gradio as gr
import torch
import torchaudio
import numpy as np
import os
import pandas as pd
from datetime import timedelta
from pathlib import Path
from transformers import (
    Wav2Vec2ForCTC,
    Wav2Vec2Processor,
    WhisperProcessor,
    WhisperForConditionalGeneration
)
from pyannote.audio import Pipeline, Model, Inference
from scipy.spatial.distance import cdist
import torchaudio.transforms as T

# --- Optional Wylie→Tibetan converter (pyewts) ---
try:
    from pyewts import pyewts
    _EWTSCONV = pyewts()
except Exception:
    _EWTSCONV = None
    print("[WARN] pyewts not available. Wylie→Tibetan conversion will be skipped.")


def ewts_to_unicode(text: str) -> str:
    if _EWTSCONV is None:
        return text
    try:
        return _EWTSCONV.toUnicode(text)
    except Exception:
        return text


# ------------------- Audio Utils -------------------
def ensure_16k(waveform, sr, target_sr=16000):
    """Ensure waveform is 16 kHz mono."""
    if waveform.ndim > 1 and waveform.shape[0] > 1:
        waveform = waveform.mean(dim=0, keepdim=True)  # stereo -> mono
    if sr != target_sr:
        resampler = T.Resample(sr, target_sr)
        waveform = resampler(waveform)
    return waveform, target_sr


# ------------------- Config -------------------
CACHE_DIR = "./models_cache"
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"[INFO] Using device: {DEVICE}")

HF_TOKEN = os.getenv("HF_TOKEN")

# --- ASR model options (type, repo, meta) ---
# meta can hold flags like {"wylie_output": True} to post-process Whisper output
MODEL_OPTIONS = {
    # Wav2Vec2 / MMS (CTC) models
    "v7 (MMS Wav2Vec2)": ("ctc", "ganga4364/mms_300_Garchen_Rinpoche-v7-scracth-Checkpoint-23000", {}),
    "v6 (MMS Wav2Vec2)": ("ctc", "ganga4364/mms_300_Garchen_Rinpoche-v6-ft-Checkpoint-25000", {}),
    "v5 (MMS Wav2Vec2)": ("ctc", "ganga4364/mms_300_Garchen_Rinpoche-v5-base-Checkpoint-28000", {}),
    "v2 (MMS Wav2Vec2)": ("ctc", "openpecha/mms_300_Garchen_Rinpoche-v2-Checkpoint-43000", {}),
    "v4 (MMS Wav2Vec2)": ("ctc", "openpecha/mms_300_Garchen_Rinpoche-v4-Checkpoint-22000", {}),
"v3 (MMS Wav2Vec2)": ("ctc", "openpecha/mms_300_Garchen_Rinpoche-v3-Checkpoint-25000", {}), "v1 (MMS Wav2Vec2)": ("ctc", "openpecha/Garchen_Rinpoche_stt", {}), "base (MMS Wav2Vec2)": ("ctc", "openpecha/general_stt_base_model", {}), # Whisper (seq2seq) models "Whisper (Wylie, default tokenizer)": ( "whisper", "ganga4364/whisper-small-tibetan-wylie-checkpoint-4000", {"wylie_output": True} # convert to Tibetan via pyewts ), "Whisper (Tibetan, added tokens)": ( "whisper", "ganga4364/whisper-small-latin-added-tibetan-checkpoint-4000", {"wylie_output": False} # already Tibetan script ), } # Cache for ASR models asr_cache = {} def load_asr_model(choice): """Load either a CTC (Wav2Vec2) or Whisper model + processor based on dropdown choice.""" if choice not in MODEL_OPTIONS: raise ValueError(f"Unknown model choice: {choice}") model_type, repo, meta = MODEL_OPTIONS[choice] if choice not in asr_cache: print(f"[INFO] Loading ASR model: {choice} ({model_type}) -> {repo}") if model_type == "ctc": model = Wav2Vec2ForCTC.from_pretrained(repo, cache_dir=CACHE_DIR).to(DEVICE) processor = Wav2Vec2Processor.from_pretrained(repo, cache_dir=CACHE_DIR) model.eval() elif model_type == "whisper": processor = WhisperProcessor.from_pretrained( repo, cache_dir=CACHE_DIR, language="Tibetan", task="transcribe" ) model = WhisperForConditionalGeneration.from_pretrained(repo, cache_dir=CACHE_DIR).to(DEVICE) model.eval() else: raise ValueError(f"Unsupported model type: {model_type}") asr_cache[choice] = (model_type, model, processor, meta) return asr_cache[choice] # ------------------- Whisper Large v3 (fallback for other speakers) ------------------- print("[INFO] Loading Whisper Large V3 for other speakers...") whisper_model_lg = WhisperForConditionalGeneration.from_pretrained( "openai/whisper-large-v3", cache_dir=CACHE_DIR ).to(DEVICE) whisper_proc_lg = WhisperProcessor.from_pretrained("openai/whisper-large-v3", cache_dir=CACHE_DIR) whisper_model_lg.eval() def 
transcribe_with_whisper_large(waveform, sr): waveform, sr = ensure_16k(waveform, sr) if waveform.shape[1] < 400: return "" inputs = whisper_proc_lg(waveform.squeeze(), sampling_rate=sr, return_tensors="pt") input_features = inputs["input_features"].to(DEVICE) forced_ids = whisper_proc_lg.get_decoder_prompt_ids(language="Tibetan", task="transcribe") with torch.no_grad(): pred_ids = whisper_model_lg.generate( input_features, forced_decoder_ids=forced_ids, num_beams=4, max_length=225 ) return whisper_proc_lg.batch_decode(pred_ids, skip_special_tokens=True)[0].strip() # ------------------- Pyannote ------------------- try: diarization_pipeline = Pipeline.from_pretrained( "pyannote/speaker-diarization-3.1", token=HF_TOKEN, cache_dir=CACHE_DIR ).to(DEVICE) print("Pyannote diarization loaded") except Exception as e: diarization_pipeline = None print(f"[WARN] Pyannote diarization not available: {e}") # Embedding model for voice print embedding_model = Model.from_pretrained("pyannote/wespeaker-voxceleb-resnet34-LM", cache_dir=CACHE_DIR) embedding_inference = Inference(embedding_model, window="whole") # ------------------- Helpers ------------------- MAX_SEGMENT_SEC = 15 def format_timestamp(seconds, format_type="srt"): td = timedelta(seconds=seconds) hours, remainder = divmod(td.seconds, 3600) minutes, seconds = divmod(remainder, 60) milliseconds = round(td.microseconds / 1000) if format_type == "srt": return f"{hours:02d}:{minutes:02d}:{seconds:02d},{milliseconds:03d}" else: return f"{hours:02d}:{minutes:02d}:{seconds:02d}.{milliseconds:03d}" def create_subtitle_file(timestamps_with_text, output_path, format_type="srt"): with open(output_path, "w", encoding="utf-8") as f: if format_type == "vtt": f.write("WEBVTT\n\n") for i, (start, end, text, speaker) in enumerate(timestamps_with_text, 1): if format_type == "srt": f.write(f"{i}\n") f.write(f"{format_timestamp(start)} --> {format_timestamp(end)}\n") f.write(f"{speaker}: {text}\n\n") else: f.write(f"{format_timestamp(start, 
'vtt')} --> {format_timestamp(end, 'vtt')}\n") f.write(f"{text}\n\n") def split_long_segment(start, end, max_length=MAX_SEGMENT_SEC): segments = [] total_duration = end - start if total_duration <= max_length: return [(start, end)] current = start while current < end: seg_end = min(current + max_length, end) segments.append((current, seg_end)) current = seg_end return segments # --- CTC inference (Wav2Vec2/MMS) --- def transcribe_segment_ctc(waveform, sr, model, processor): waveform, sr = ensure_16k(waveform, sr) if waveform.shape[1] < 400: return "" inputs = processor(waveform.squeeze(), sampling_rate=sr, return_tensors="pt", padding=True) inputs = {k: v.to(DEVICE) for k, v in inputs.items()} with torch.no_grad(): logits = model(**inputs).logits pred_ids = torch.argmax(logits, dim=-1) return processor.decode(pred_ids[0].cpu()) # --- Whisper inference (seq2seq) with optional Wylie→Tibetan conversion --- def transcribe_segment_whisper(waveform, sr, model, processor, wylie_output: bool = False): waveform, sr = ensure_16k(waveform, sr) if waveform.shape[1] < 400: return "" inputs = processor(waveform.squeeze(), sampling_rate=sr, return_tensors="pt") input_features = inputs["input_features"].to(DEVICE) forced_ids = processor.get_decoder_prompt_ids(language="Tibetan", task="transcribe") with torch.no_grad(): pred_ids = model.generate( input_features, forced_decoder_ids=forced_ids, num_beams=4, max_length=225 ) text = processor.batch_decode(pred_ids, skip_special_tokens=True)[0].strip() # If this Whisper model outputs Wylie, convert to Tibetan Unicode if wylie_output: text = ewts_to_unicode(text) return text # ------------------- Speaker Identification ------------------- def identify_speaker(diarization_df, audio_path, voice_print_embedding, speaker_name, inference, threshold=0.6, n_segments=3): waveform, sr = torchaudio.load(audio_path) speaker_distances = {} for speaker in diarization_df['speaker'].unique(): sp_df = diarization_df[diarization_df['speaker'] == 
speaker].copy() sp_df['duration'] = sp_df['end'] - sp_df['start'] sp_df = sp_df.sort_values(by='duration', ascending=False).head(n_segments) distances = [] for _, row in sp_df.iterrows(): start, end = int(row['start']*sr), int(row['end']*sr) segment = waveform[:, start:end] seg_path = f"/tmp/{speaker}_{start}_{end}.wav" torchaudio.save(seg_path, segment, sr) try: seg_embedding = inference(seg_path) seg_embedding = np.atleast_2d(seg_embedding) dist = cdist(seg_embedding, voice_print_embedding, metric="cosine")[0, 0] distances.append(dist) except Exception as e: print(f"Error embedding segment {speaker} {row['start']}-{row['end']}: {e}") if distances: speaker_distances[speaker] = np.mean(distances) if not speaker_distances: return None, {}, diarization_df best_match = min(speaker_distances, key=speaker_distances.get) min_distance = speaker_distances[best_match] if min_distance <= threshold: mapping = {sp: speaker_name if sp == best_match else f"Other Speaker {i}" for i, sp in enumerate(speaker_distances.keys())} else: mapping = {sp: f"Speaker {i}" for i, sp in enumerate(speaker_distances.keys())} diarization_df['identified_speaker'] = diarization_df['speaker'].map(mapping) return best_match, mapping, diarization_df # ------------------- Main ------------------- def process_audio(model_choice, mode, voice_print_path, audio_path, speaker_name, threshold=0.6): # Load full audio waveform, sample_rate = torchaudio.load(audio_path) waveform, sample_rate = ensure_16k(waveform, sample_rate) # Load selected ASR (CTC or Whisper) + meta flags model_type, asr_model, asr_processor, meta = load_asr_model(model_choice) wylie_output = bool(meta.get("wylie_output", False)) # Voice print vp_waveform, vp_sr = torchaudio.load(voice_print_path) vp_waveform, vp_sr = ensure_16k(vp_waveform, vp_sr) tmp_vp = "/tmp/voice_print_16k.wav" torchaudio.save(tmp_vp, vp_waveform, vp_sr) voice_print_embedding = embedding_inference(tmp_vp) voice_print_embedding = np.atleast_2d(voice_print_embedding) 
    results = []

    if "Diarization" in mode:
        if diarization_pipeline is None:
            return "Pyannote diarization is not available.", None, None

        diarization = diarization_pipeline({"waveform": waveform, "sample_rate": sample_rate})

        data = []
        if hasattr(diarization, 'speaker_diarization'):
            # Newer pipelines expose results via a speaker_diarization attribute
            for turn, speaker in diarization.speaker_diarization:
                data.append({"start": turn.start, "end": turn.end, "speaker": speaker})
        elif hasattr(diarization, 'itertracks'):
            # Classic Annotation object (older pyannote.audio versions)
            for segment, track, speaker in diarization.itertracks(yield_label=True):
                data.append({"start": segment.start, "end": segment.end, "speaker": speaker})
        else:
            return "Unsupported pyannote.audio version. Please check the diarization output format.", None, None

        if not data:
            return "No speaker segments found in diarization.", None, None

        diarization_df = pd.DataFrame(data)

        # Identify target speaker
        _, mapping, diarization_df = identify_speaker(
            diarization_df, audio_path, voice_print_embedding,
            speaker_name, embedding_inference, threshold
        )

        for _, row in diarization_df.iterrows():
            for seg_start, seg_end in split_long_segment(row['start'], row['end']):
                seg_waveform = waveform[:, int(seg_start * sample_rate):int(seg_end * sample_rate)]
                if row['identified_speaker'] == speaker_name:
                    # Target speaker -> use selected ASR path
                    if model_type == "ctc":
                        transcription = transcribe_segment_ctc(seg_waveform, sample_rate, asr_model, asr_processor)
                    else:  # whisper
                        transcription = transcribe_segment_whisper(
                            seg_waveform, sample_rate, asr_model, asr_processor,
                            wylie_output=wylie_output
                        )
                else:
                    if mode == "Diarization (Target Speaker Only)":
                        transcription = ""  # skip other speakers
                    else:
                        # Other speakers -> Whisper Large v3 fallback (already Tibetan)
                        transcription = transcribe_with_whisper_large(seg_waveform, sample_rate)
                results.append((seg_start, seg_end, transcription, row['identified_speaker']))

    # Save subtitle files
    base_path = os.path.splitext(audio_path)[0]
    srt_path = f"{base_path}_identified.srt"
    vtt_path = f"{base_path}_identified.vtt"
    create_subtitle_file(results, srt_path, "srt")
    create_subtitle_file(results, vtt_path, "vtt")

    transcript_text = "\n".join([f"{sp}: {txt}" for (_, _, txt, sp) in results])
    return transcript_text, srt_path, vtt_path


# ------------------- Gradio -------------------
demo = gr.Interface(
    fn=process_audio,
    inputs=[
        gr.Dropdown(
            choices=list(MODEL_OPTIONS.keys()),
            value="base (MMS Wav2Vec2)",
            label="Select ASR Model"
        ),
        gr.Radio(
            choices=["Diarization (Transcribe All)", "Diarization (Target Speaker Only)"],
            value="Diarization (Transcribe All)",
            label="Segmentation Method"
        ),
        gr.Audio(sources=["upload"], type="filepath", label="Voice Print Audio"),
        gr.Audio(sources=["upload"], type="filepath", label="Full Audio"),
        gr.Textbox(value="GR", label="Speaker Name for Voice Print")
    ],
    outputs=[
        gr.Textbox(
            label="Transcript",
            lines=24,        # visible height (try 20-30)
            max_lines=60,    # optional scroll cap
            show_copy_button=True
        ),
        gr.File(label="SRT File"),
        gr.File(label="WebVTT File")
    ],
    title="STT + Speaker Identification",
    description=(
        "Choose an ASR model (MMS Wav2Vec2 or a fine-tuned Whisper). "
        "If you pick 'Whisper (Wylie, default tokenizer)', the output is converted from Wylie to Tibetan Unicode via pyewts. "
        "Target speaker → chosen model; other speakers → Whisper Large v3."
    )
)

if __name__ == "__main__":
    demo.launch(share=True)
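
# --- Illustrative examples (comments only; not executed) ---
# A quick sketch of what the timing helpers above produce, with MAX_SEGMENT_SEC = 15:
#
#   split_long_segment(0.0, 35.0)   -> [(0.0, 15.0), (15.0, 30.0), (30.0, 35.0)]
#   split_long_segment(2.0, 10.0)   -> [(2.0, 10.0)]          # under the cap, unchanged
#   format_timestamp(3661.5)        -> "01:01:01,500"         # SRT uses a comma
#   format_timestamp(3661.5, "vtt") -> "01:01:01.500"         # WebVTT uses a period
#
# create_subtitle_file then writes one numbered "start --> end" cue per
# (start, end, text, speaker) tuple, prefixing the speaker label in SRT output.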