# Hugging Face Space: Tibetan STT + speaker identification (Gradio app).
# (The "Spaces: Runtime error" lines above the code were page-capture
# artifacts from the Space's status banner, not part of the program.)
import os
import tempfile
from datetime import timedelta
from pathlib import Path

import gradio as gr
import numpy as np
import pandas as pd
import torch
import torchaudio
import torchaudio.transforms as T
from pyannote.audio import Inference, Model, Pipeline
from scipy.spatial.distance import cdist
from transformers import (
    Wav2Vec2ForCTC,
    Wav2Vec2Processor,
    WhisperForConditionalGeneration,
    WhisperProcessor,
)
# --- Optional Wylie→Tibetan converter (pyewts) ---
try:
    from pyewts import pyewts
    _EWTSCONV = pyewts()
except Exception:
    _EWTSCONV = None
    print("[WARN] pyewts not available. Wylie→Tibetan conversion will be skipped.")


def ewts_to_unicode(text: str) -> str:
    """Convert Wylie (EWTS) transliteration to Tibetan Unicode.

    Returns the input unchanged when pyewts is unavailable or the
    conversion fails, so callers never have to special-case it.
    """
    if _EWTSCONV is None:
        return text
    try:
        converted = _EWTSCONV.toUnicode(text)
    except Exception:
        return text
    return converted
# ------------------- Audio Utils -------------------
def ensure_16k(waveform, sr, target_sr=16000):
    """Return (waveform, target_sr) with the audio downmixed to mono and
    resampled to ``target_sr`` when the input rate differs.

    Expects a (channels, samples) tensor as produced by ``torchaudio.load``.
    """
    is_multichannel = waveform.ndim > 1 and waveform.shape[0] > 1
    if is_multichannel:
        # Average channels -> mono, keeping the leading channel dim.
        waveform = waveform.mean(dim=0, keepdim=True)
    if sr != target_sr:
        waveform = T.Resample(sr, target_sr)(waveform)
    return waveform, target_sr
# ------------------- Config -------------------
# Local directory for all Hugging Face model downloads.
CACHE_DIR = "./models_cache"
# Prefer GPU when available; every model below is moved to this device.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"[INFO] Using device: {DEVICE}")
# Hugging Face access token (needed for the gated pyannote pipeline); may be None.
HF_TOKEN = os.getenv("HF_TOKEN")
# --- ASR model options (type, repo, meta) ---
# Maps a dropdown label to (model_type, HF repo id, meta).
# model_type is "ctc" (Wav2Vec2/MMS, greedy CTC decode) or "whisper" (seq2seq).
# meta can hold flags like { "wylie_output": True } to post-process Whisper output
# from Wylie transliteration into Tibetan Unicode.
MODEL_OPTIONS = {
    # Wav2Vec2 / MMS (CTC) models
    "v7 (MMS Wav2Vec2)": ("ctc", "ganga4364/mms_300_Garchen_Rinpoche-v7-scracth-Checkpoint-23000", {}),
    "v6 (MMS Wav2Vec2)": ("ctc", "ganga4364/mms_300_Garchen_Rinpoche-v6-ft-Checkpoint-25000", {}),
    "v5 (MMS Wav2Vec2)": ("ctc", "ganga4364/mms_300_Garchen_Rinpoche-v5-base-Checkpoint-28000", {}),
    "v2 (MMS Wav2Vec2)": ("ctc", "openpecha/mms_300_Garchen_Rinpoche-v2-Checkpoint-43000", {}),
    "v4 (MMS Wav2Vec2)": ("ctc", "openpecha/mms_300_Garchen_Rinpoche-v4-Checkpoint-22000", {}),
    "v3 (MMS Wav2Vec2)": ("ctc", "openpecha/mms_300_Garchen_Rinpoche-v3-Checkpoint-25000", {}),
    "v1 (MMS Wav2Vec2)": ("ctc", "openpecha/Garchen_Rinpoche_stt", {}),
    "base (MMS Wav2Vec2)": ("ctc", "openpecha/general_stt_base_model", {}),
    # Whisper (seq2seq) models
    "Whisper (Wylie, default tokenizer)": (
        "whisper",
        "ganga4364/whisper-small-tibetan-wylie-checkpoint-4000",
        {"wylie_output": True}  # convert to Tibetan via pyewts
    ),
    "Whisper (Tibetan, added tokens)": (
        "whisper",
        "ganga4364/whisper-small-latin-added-tibetan-checkpoint-4000",
        {"wylie_output": False}  # already Tibetan script
    ),
}
# Cache for ASR models: dropdown choice -> (model_type, model, processor, meta)
asr_cache = {}
def load_asr_model(choice):
    """Return (model_type, model, processor, meta) for a dropdown choice.

    Models are loaded lazily and memoized in the module-level ``asr_cache``
    so each repo is downloaded and initialized at most once per process.

    Raises:
        ValueError: if ``choice`` is not in MODEL_OPTIONS or has an
            unsupported model type.
    """
    if choice not in MODEL_OPTIONS:
        raise ValueError(f"Unknown model choice: {choice}")
    model_type, repo, meta = MODEL_OPTIONS[choice]
    cached = asr_cache.get(choice)
    if cached is not None:
        return cached
    print(f"[INFO] Loading ASR model: {choice} ({model_type}) -> {repo}")
    if model_type == "ctc":
        processor = Wav2Vec2Processor.from_pretrained(repo, cache_dir=CACHE_DIR)
        model = Wav2Vec2ForCTC.from_pretrained(repo, cache_dir=CACHE_DIR).to(DEVICE)
    elif model_type == "whisper":
        processor = WhisperProcessor.from_pretrained(
            repo, cache_dir=CACHE_DIR, language="Tibetan", task="transcribe"
        )
        model = WhisperForConditionalGeneration.from_pretrained(repo, cache_dir=CACHE_DIR).to(DEVICE)
    else:
        raise ValueError(f"Unsupported model type: {model_type}")
    model.eval()
    entry = (model_type, model, processor, meta)
    asr_cache[choice] = entry
    return entry
# ------------------- Whisper Large v3 (fallback for other speakers) -------------------
# Loaded eagerly at import time; used to transcribe segments belonging to any
# speaker other than the identified target.
print("[INFO] Loading Whisper Large V3 for other speakers...")
whisper_model_lg = WhisperForConditionalGeneration.from_pretrained(
    "openai/whisper-large-v3", cache_dir=CACHE_DIR
).to(DEVICE)
whisper_proc_lg = WhisperProcessor.from_pretrained("openai/whisper-large-v3", cache_dir=CACHE_DIR)
whisper_model_lg.eval()
def transcribe_with_whisper_large(waveform, sr):
    """Transcribe one audio segment with openai/whisper-large-v3.

    Used as the fallback path for non-target speakers. Returns "" for
    segments too short (< 400 samples at 16 kHz) to featurize.
    """
    waveform, sr = ensure_16k(waveform, sr)
    if waveform.shape[1] < 400:
        return ""
    features = whisper_proc_lg(
        waveform.squeeze(), sampling_rate=sr, return_tensors="pt"
    )["input_features"].to(DEVICE)
    prompt_ids = whisper_proc_lg.get_decoder_prompt_ids(language="Tibetan", task="transcribe")
    with torch.no_grad():
        generated = whisper_model_lg.generate(
            features,
            forced_decoder_ids=prompt_ids,
            num_beams=4,
            max_length=225,
        )
    decoded = whisper_proc_lg.batch_decode(generated, skip_special_tokens=True)
    return decoded[0].strip()
# ------------------- Pyannote -------------------
try:
    # Gated model: requires an HF_TOKEN with access to pyannote/speaker-diarization-3.1.
    diarization_pipeline = Pipeline.from_pretrained(
        "pyannote/speaker-diarization-3.1", token=HF_TOKEN, cache_dir=CACHE_DIR
    ).to(DEVICE)
    print("Pyannote diarization loaded")
except Exception as e:
    # The app still starts without diarization; process_audio reports it as unavailable.
    diarization_pipeline = None
    print(f"[WARN] Pyannote diarization not available: {e}")
# Embedding model for voice print
# NOTE(review): loaded outside the try block and without a token — if this repo
# is gated or unreachable, the whole app fails at import; confirm intended.
embedding_model = Model.from_pretrained("pyannote/wespeaker-voxceleb-resnet34-LM", cache_dir=CACHE_DIR)
embedding_inference = Inference(embedding_model, window="whole")
# ------------------- Helpers -------------------
# Diarized turns longer than this are chopped into <= 15 s chunks before ASR.
MAX_SEGMENT_SEC = 15
def format_timestamp(seconds, format_type="srt"):
    """Format a time offset (float seconds) as a subtitle timestamp.

    Args:
        seconds: non-negative offset from the start of the audio.
        format_type: "srt" -> "HH:MM:SS,mmm"; anything else (e.g. "vtt")
            -> "HH:MM:SS.mmm".

    The previous timedelta-based version had two defects fixed here:
    rounding microseconds to milliseconds could yield 1000 (producing a
    malformed "...,1000"), and ``timedelta.seconds`` silently dropped whole
    days for offsets >= 24 h. Working in integer milliseconds avoids both.
    """
    total_ms = round(seconds * 1000)
    hours, remainder = divmod(total_ms, 3_600_000)
    minutes, remainder = divmod(remainder, 60_000)
    secs, millis = divmod(remainder, 1_000)
    separator = "," if format_type == "srt" else "."
    return f"{hours:02d}:{minutes:02d}:{secs:02d}{separator}{millis:03d}"
def create_subtitle_file(timestamps_with_text, output_path, format_type="srt"):
    """Write (start, end, text, speaker) tuples to an SRT or WebVTT file.

    SRT entries are numbered from 1 and use "Speaker: text" lines; VTT
    entries use the ``<v Speaker>`` voice tag and a leading WEBVTT header.
    """
    lines = []
    if format_type == "vtt":
        lines.append("WEBVTT\n\n")
    for index, (start, end, text, speaker) in enumerate(timestamps_with_text, 1):
        if format_type == "srt":
            lines.append(f"{index}\n")
            lines.append(f"{format_timestamp(start)} --> {format_timestamp(end)}\n")
            lines.append(f"{speaker}: {text}\n\n")
        else:
            lines.append(f"{format_timestamp(start, 'vtt')} --> {format_timestamp(end, 'vtt')}\n")
            lines.append(f"<v {speaker}>{text}\n\n")
    with open(output_path, "w", encoding="utf-8") as f:
        f.writelines(lines)
def split_long_segment(start, end, max_length=None):
    """Split the interval [start, end] into consecutive chunks of at most
    ``max_length`` seconds.

    Args:
        start, end: segment boundaries in seconds (end >= start).
        max_length: maximum chunk length; defaults to the module-level
            MAX_SEGMENT_SEC. Resolved at call time (instead of being
            captured at definition time) so later changes to the module
            constant take effect — a backward-compatible generalization.

    Returns:
        List of (chunk_start, chunk_end) tuples covering [start, end];
        a single tuple when the segment already fits.
    """
    if max_length is None:
        max_length = MAX_SEGMENT_SEC
    if end - start <= max_length:
        return [(start, end)]
    chunks = []
    cursor = start
    while cursor < end:
        chunk_end = min(cursor + max_length, end)
        chunks.append((cursor, chunk_end))
        cursor = chunk_end
    return chunks
# --- CTC inference (Wav2Vec2/MMS) ---
def transcribe_segment_ctc(waveform, sr, model, processor):
    """Greedy-decode one audio segment with a Wav2Vec2/MMS CTC model.

    Returns "" for segments too short (< 400 samples at 16 kHz) to featurize.
    """
    waveform, sr = ensure_16k(waveform, sr)
    if waveform.shape[1] < 400:
        return ""
    batch = processor(waveform.squeeze(), sampling_rate=sr, return_tensors="pt", padding=True)
    batch = {key: tensor.to(DEVICE) for key, tensor in batch.items()}
    with torch.no_grad():
        logits = model(**batch).logits
    best_ids = logits.argmax(dim=-1)
    return processor.decode(best_ids[0].cpu())
# --- Whisper inference (seq2seq) with optional Wylie→Tibetan conversion ---
def transcribe_segment_whisper(waveform, sr, model, processor, wylie_output: bool = False):
    """Transcribe one audio segment with a fine-tuned Whisper model.

    When ``wylie_output`` is True, the decoded Wylie transliteration is
    converted to Tibetan Unicode via pyewts (a no-op if pyewts is missing).
    Returns "" for segments too short to featurize.
    """
    waveform, sr = ensure_16k(waveform, sr)
    if waveform.shape[1] < 400:
        return ""
    features = processor(waveform.squeeze(), sampling_rate=sr, return_tensors="pt")["input_features"]
    features = features.to(DEVICE)
    prompt_ids = processor.get_decoder_prompt_ids(language="Tibetan", task="transcribe")
    with torch.no_grad():
        generated = model.generate(
            features,
            forced_decoder_ids=prompt_ids,
            num_beams=4,
            max_length=225,
        )
    text = processor.batch_decode(generated, skip_special_tokens=True)[0].strip()
    return ewts_to_unicode(text) if wylie_output else text
# ------------------- Speaker Identification -------------------
def identify_speaker(diarization_df, audio_path, voice_print_embedding, speaker_name, inference, threshold=0.6, n_segments=3):
    """Match diarized speakers against a reference voice-print embedding.

    For each speaker in ``diarization_df`` (columns: start, end, speaker),
    embeds its ``n_segments`` longest turns and averages the cosine distance
    to ``voice_print_embedding``. If the closest speaker's mean distance is
    within ``threshold``, it is renamed to ``speaker_name`` and the others
    become "Other Speaker {i}"; otherwise everyone gets a generic
    "Speaker {i}" label.

    Returns:
        (best_match_raw_label_or_None, {raw_label: display_name}, df) where
        df gains an ``identified_speaker`` column.

    Fix vs. the previous version: per-segment wav files were written to a
    hard-coded "/tmp" and never deleted, accumulating across runs (and
    breaking on platforms without /tmp). Files now go to the platform temp
    dir and are removed as soon as they have been embedded.
    """
    waveform, sr = torchaudio.load(audio_path)
    tmp_dir = tempfile.gettempdir()
    speaker_distances = {}
    for speaker in diarization_df['speaker'].unique():
        sp_df = diarization_df[diarization_df['speaker'] == speaker].copy()
        sp_df['duration'] = sp_df['end'] - sp_df['start']
        # The longest turns give the most reliable embeddings.
        sp_df = sp_df.sort_values(by='duration', ascending=False).head(n_segments)
        distances = []
        for _, row in sp_df.iterrows():
            start, end = int(row['start'] * sr), int(row['end'] * sr)
            # The pyannote Inference wrapper is fed a file path here.
            seg_path = os.path.join(tmp_dir, f"{speaker}_{start}_{end}.wav")
            torchaudio.save(seg_path, waveform[:, start:end], sr)
            try:
                seg_embedding = np.atleast_2d(inference(seg_path))
                dist = cdist(seg_embedding, voice_print_embedding, metric="cosine")[0, 0]
                distances.append(dist)
            except Exception as e:
                # Best-effort: a failed segment just doesn't contribute a distance.
                print(f"Error embedding segment {speaker} {row['start']}-{row['end']}: {e}")
            finally:
                # Don't let per-segment wavs accumulate in the temp dir.
                try:
                    os.remove(seg_path)
                except OSError:
                    pass
        if distances:
            speaker_distances[speaker] = np.mean(distances)
    if not speaker_distances:
        return None, {}, diarization_df
    best_match = min(speaker_distances, key=speaker_distances.get)
    if speaker_distances[best_match] <= threshold:
        mapping = {sp: speaker_name if sp == best_match else f"Other Speaker {i}"
                   for i, sp in enumerate(speaker_distances.keys())}
    else:
        mapping = {sp: f"Speaker {i}" for i, sp in enumerate(speaker_distances.keys())}
    diarization_df['identified_speaker'] = diarization_df['speaker'].map(mapping)
    return best_match, mapping, diarization_df
# ------------------- Main -------------------
def process_audio(model_choice, mode, voice_print_path, audio_path, speaker_name, threshold=0.6):
    """Diarize ``audio_path``, identify the target speaker via a voice print,
    transcribe every segment, and write SRT/VTT subtitle files.

    Args:
        model_choice: key into MODEL_OPTIONS selecting the ASR model used for
            the target speaker's segments.
        mode: "Diarization (Transcribe All)" or "Diarization (Target Speaker Only)".
        voice_print_path: reference audio of the target speaker.
        audio_path: full recording to process.
        speaker_name: display name assigned to the matched target speaker.
        threshold: max mean cosine distance for a positive voice-print match.

    Returns:
        (transcript_text, srt_path, vtt_path) on success, or
        (error_message, None, None) when diarization is unavailable,
        yields no segments, or has an unrecognized output format.
    """
    # Load full audio
    waveform, sample_rate = torchaudio.load(audio_path)
    waveform, sample_rate = ensure_16k(waveform, sample_rate)
    # Load selected ASR (CTC or Whisper) + meta flags
    model_type, asr_model, asr_processor, meta = load_asr_model(model_choice)
    wylie_output = bool(meta.get("wylie_output", False))
    # Voice print: resample to 16 kHz, write to disk (the pyannote Inference
    # wrapper here is fed a file path), then embed.
    vp_waveform, vp_sr = torchaudio.load(voice_print_path)
    vp_waveform, vp_sr = ensure_16k(vp_waveform, vp_sr)
    tmp_vp = "/tmp/voice_print_16k.wav"
    torchaudio.save(tmp_vp, vp_waveform, vp_sr)
    voice_print_embedding = embedding_inference(tmp_vp)
    voice_print_embedding = np.atleast_2d(voice_print_embedding)
    results = []
    if "Diarization" in mode:  # both UI mode strings contain "Diarization"
        if diarization_pipeline is None:
            return "Pyannote diarization is not available.", None, None
        # In-memory input dict; pyannote also accepts a plain file path.
        diarization = diarization_pipeline({"waveform": waveform, "sample_rate": sample_rate})
        data = []
        # Handle both pyannote output shapes: a DiarizeOutput-style object with a
        # `speaker_diarization` attribute, or a classic Annotation with itertracks.
        if hasattr(diarization, 'speaker_diarization'):
            # NOTE(review): assumes iterating `speaker_diarization` yields
            # (turn, speaker) pairs — confirm against the installed
            # pyannote.audio version (Annotation iteration normally yields
            # segments only; itertracks(yield_label=True) yields triples).
            for turn, speaker in diarization.speaker_diarization:
                data.append({
                    "start": turn.start,
                    "end": turn.end,
                    "speaker": speaker
                })
        elif hasattr(diarization, 'itertracks'):
            # Annotation API: (segment, track, label) triples.
            for segment, track, speaker in diarization.itertracks(yield_label=True):
                data.append({
                    "start": segment.start,
                    "end": segment.end,
                    "speaker": speaker
                })
        else:
            return "Unsupported pyannote.audio version. Please check the diarization output format.", None, None
        if not data:
            return "No speaker segments found in diarization.", None, None
        diarization_df = pd.DataFrame(data)
        # Identify target speaker: adds an 'identified_speaker' column.
        _, mapping, diarization_df = identify_speaker(
            diarization_df, audio_path, voice_print_embedding, speaker_name, embedding_inference, threshold
        )
        for _, row in diarization_df.iterrows():
            # Long turns are chunked so each ASR call stays <= MAX_SEGMENT_SEC.
            for seg_start, seg_end in split_long_segment(row['start'], row['end']):
                seg_waveform = waveform[:, int(seg_start*sample_rate):int(seg_end*sample_rate)]
                if row['identified_speaker'] == speaker_name:
                    # Target speaker -> use selected ASR path
                    if model_type == "ctc":
                        transcription = transcribe_segment_ctc(seg_waveform, sample_rate, asr_model, asr_processor)
                    else:  # whisper
                        transcription = transcribe_segment_whisper(
                            seg_waveform, sample_rate, asr_model, asr_processor, wylie_output=wylie_output
                        )
                else:
                    if mode == "Diarization (Target Speaker Only)":
                        transcription = ""  # skip other speakers
                    else:
                        # Other speakers -> Whisper Large v3 fallback (already Tibetan)
                        transcription = transcribe_with_whisper_large(seg_waveform, sample_rate)
                results.append((seg_start, seg_end, transcription, row['identified_speaker']))
    # Save subtitle files next to the input audio.
    base_path = os.path.splitext(audio_path)[0]
    srt_path = f"{base_path}_identified.srt"
    vtt_path = f"{base_path}_identified.vtt"
    create_subtitle_file(results, srt_path, "srt")
    create_subtitle_file(results, vtt_path, "vtt")
    transcript_text = "\n".join([f"{sp}: {txt}" for (_, _, txt, sp) in results])
    return transcript_text, srt_path, vtt_path
# ------------------- Gradio -------------------
# UI wiring: the five inputs map positionally onto process_audio's parameters.
demo = gr.Interface(
    fn=process_audio,
    inputs=[
        # ASR model used only for the *target* speaker's segments.
        gr.Dropdown(
            choices=list(MODEL_OPTIONS.keys()),
            value="base (MMS Wav2Vec2)",
            label="Select ASR Model"
        ),
        gr.Radio(
            choices=["Diarization (Transcribe All)", "Diarization (Target Speaker Only)"],
            value="Diarization (Transcribe All)",
            label="Segmentation Method"
        ),
        gr.Audio(sources=["upload"], type="filepath", label="Voice Print Audio"),
        gr.Audio(sources=["upload"], type="filepath", label="Full Audio"),
        gr.Textbox(value="GR", label="Speaker Name for Voice Print")
        # NOTE(review): process_audio also takes `threshold`; it always runs
        # with its default (0.6) because no UI component is wired to it.
    ],
    outputs=[
        gr.Textbox(
            label="Transcript",
            lines=24,  # visible height
            max_lines=60,  # scroll cap
            show_copy_button=True
        ),
        gr.File(label="SRT File"),
        gr.File(label="WebVTT File")
    ],
    title="STT + Speaker Identification",
    description=(
        "Choose an ASR model (MMS Wav2Vec2 or your fine-tuned Whisper). "
        "If you pick 'Whisper (Wylie, default tokenizer)', the output will be converted from Wylie to Tibetan Unicode via pyewts. "
        "Target speaker → chosen model; other speakers → Whisper Large v3."
    )
)

if __name__ == "__main__":
    # share=True exposes a public Gradio link when run outside HF Spaces.
    demo.launch(share=True)