import gradio as gr
import torch
import torchaudio
import numpy as np
import os
import pandas as pd
from pathlib import Path
from transformers import (
    Wav2Vec2ForCTC,
    Wav2Vec2Processor,
    WhisperProcessor,
    WhisperForConditionalGeneration,
)
from pyannote.audio import Pipeline, Model, Inference
from scipy.spatial.distance import cdist
import torchaudio.transforms as T
# --- Optional Wylie→Tibetan converter (pyewts) ---
try:
    from pyewts import pyewts
    _EWTSCONV = pyewts()
except Exception:
    _EWTSCONV = None
    print("[WARN] pyewts not available. Wylie→Tibetan conversion will be skipped.")
def ewts_to_unicode(text: str) -> str:
    """Convert Wylie (EWTS) to Tibetan Unicode; return the input unchanged if conversion fails."""
    if _EWTSCONV is None:
        return text
    try:
        return _EWTSCONV.toUnicode(text)
    except Exception:
        return text
# ------------------- Audio Utils -------------------
def ensure_16k(waveform, sr, target_sr=16000):
    """Ensure waveform is 16 kHz mono."""
    if waveform.ndim > 1 and waveform.shape[0] > 1:
        waveform = waveform.mean(dim=0, keepdim=True)  # stereo -> mono
    if sr != target_sr:
        resampler = T.Resample(sr, target_sr)
        waveform = resampler(waveform)
    return waveform, target_sr
# ------------------- Config -------------------
CACHE_DIR = "./models_cache"
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"[INFO] Using device: {DEVICE}")
HF_TOKEN = os.getenv("HF_TOKEN")
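# The gated pyannote checkpoints below require this token (accept their terms on the Hub first)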
# --- ASR model options (type, repo, meta) ---
# meta can hold flags like { "wylie_output": True } to post-process Whisper output
MODEL_OPTIONS = {
    # Wav2Vec2 / MMS (CTC) models
    "v7 (MMS Wav2Vec2)": ("ctc", "ganga4364/mms_300_Garchen_Rinpoche-v7-scracth-Checkpoint-23000", {}),
    "v6 (MMS Wav2Vec2)": ("ctc", "ganga4364/mms_300_Garchen_Rinpoche-v6-ft-Checkpoint-25000", {}),
    "v5 (MMS Wav2Vec2)": ("ctc", "ganga4364/mms_300_Garchen_Rinpoche-v5-base-Checkpoint-28000", {}),
    "v2 (MMS Wav2Vec2)": ("ctc", "openpecha/mms_300_Garchen_Rinpoche-v2-Checkpoint-43000", {}),
    "v4 (MMS Wav2Vec2)": ("ctc", "openpecha/mms_300_Garchen_Rinpoche-v4-Checkpoint-22000", {}),
    "v3 (MMS Wav2Vec2)": ("ctc", "openpecha/mms_300_Garchen_Rinpoche-v3-Checkpoint-25000", {}),
    "v1 (MMS Wav2Vec2)": ("ctc", "openpecha/Garchen_Rinpoche_stt", {}),
    "base (MMS Wav2Vec2)": ("ctc", "openpecha/general_stt_base_model", {}),
    # Whisper (seq2seq) models
    "Whisper (Wylie, default tokenizer)": (
        "whisper",
        "ganga4364/whisper-small-tibetan-wylie-checkpoint-4000",
        {"wylie_output": True},  # convert to Tibetan via pyewts
    ),
    "Whisper (Tibetan, added tokens)": (
        "whisper",
        "ganga4364/whisper-small-latin-added-tibetan-checkpoint-4000",
        {"wylie_output": False},  # already Tibetan script
    ),
}
# Cache for ASR models
asr_cache = {}
def load_asr_model(choice):
    """Load either a CTC (Wav2Vec2) or Whisper model + processor based on dropdown choice."""
    if choice not in MODEL_OPTIONS:
        raise ValueError(f"Unknown model choice: {choice}")
    model_type, repo, meta = MODEL_OPTIONS[choice]
    if choice not in asr_cache:
        print(f"[INFO] Loading ASR model: {choice} ({model_type}) -> {repo}")
        if model_type == "ctc":
            model = Wav2Vec2ForCTC.from_pretrained(repo, cache_dir=CACHE_DIR).to(DEVICE)
            processor = Wav2Vec2Processor.from_pretrained(repo, cache_dir=CACHE_DIR)
        elif model_type == "whisper":
            processor = WhisperProcessor.from_pretrained(
                repo, cache_dir=CACHE_DIR, language="Tibetan", task="transcribe"
            )
            model = WhisperForConditionalGeneration.from_pretrained(repo, cache_dir=CACHE_DIR).to(DEVICE)
        else:
            raise ValueError(f"Unsupported model type: {model_type}")
        model.eval()
        asr_cache[choice] = (model_type, model, processor, meta)
    return asr_cache[choice]
# ------------------- Whisper Large v3 (fallback for other speakers) -------------------
print("[INFO] Loading Whisper Large V3 for other speakers...")
whisper_model_lg = WhisperForConditionalGeneration.from_pretrained(
    "openai/whisper-large-v3", cache_dir=CACHE_DIR
).to(DEVICE)
whisper_proc_lg = WhisperProcessor.from_pretrained("openai/whisper-large-v3", cache_dir=CACHE_DIR)
whisper_model_lg.eval()
def transcribe_with_whisper_large(waveform, sr):
    waveform, sr = ensure_16k(waveform, sr)
    if waveform.shape[1] < 400:  # fewer than 400 samples (~25 ms at 16 kHz): too short to transcribe
        return ""
    inputs = whisper_proc_lg(waveform.squeeze(), sampling_rate=sr, return_tensors="pt")
    input_features = inputs["input_features"].to(DEVICE)
    forced_ids = whisper_proc_lg.get_decoder_prompt_ids(language="Tibetan", task="transcribe")
    with torch.no_grad():
        pred_ids = whisper_model_lg.generate(
            input_features, forced_decoder_ids=forced_ids, num_beams=4, max_length=225
        )
    return whisper_proc_lg.batch_decode(pred_ids, skip_special_tokens=True)[0].strip()
# ------------------- Pyannote -------------------
try:
    diarization_pipeline = Pipeline.from_pretrained(
        "pyannote/speaker-diarization-3.1", token=HF_TOKEN, cache_dir=CACHE_DIR
    ).to(DEVICE)
    print("[INFO] Pyannote diarization loaded")
except Exception as e:
    diarization_pipeline = None
    print(f"[WARN] Pyannote diarization not available: {e}")
# Embedding model for voice print
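# window="whole" returns a single embedding per audio file; segments are compared by cosine distance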
embedding_model = Model.from_pretrained("pyannote/wespeaker-voxceleb-resnet34-LM", cache_dir=CACHE_DIR)
embedding_inference = Inference(embedding_model, window="whole")
# ------------------- Helpers -------------------
MAX_SEGMENT_SEC = 15
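# Long diarized turns are chunked to at most 15 s, which keeps Whisper well inside its 30 s window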
def format_timestamp(seconds, format_type="srt"):
    """Format seconds as an SRT ("HH:MM:SS,mmm") or WebVTT ("HH:MM:SS.mmm") timestamp."""
    total_ms = round(seconds * 1000)
    hours, remainder = divmod(total_ms, 3_600_000)
    minutes, remainder = divmod(remainder, 60_000)
    secs, milliseconds = divmod(remainder, 1000)
    if format_type == "srt":
        return f"{hours:02d}:{minutes:02d}:{secs:02d},{milliseconds:03d}"
    return f"{hours:02d}:{minutes:02d}:{secs:02d}.{milliseconds:03d}"
def create_subtitle_file(timestamps_with_text, output_path, format_type="srt"):
    with open(output_path, "w", encoding="utf-8") as f:
        if format_type == "vtt":
            f.write("WEBVTT\n\n")
        for i, (start, end, text, speaker) in enumerate(timestamps_with_text, 1):
            if format_type == "srt":
                f.write(f"{i}\n")
                f.write(f"{format_timestamp(start)} --> {format_timestamp(end)}\n")
                f.write(f"{speaker}: {text}\n\n")
            else:
                f.write(f"{format_timestamp(start, 'vtt')} --> {format_timestamp(end, 'vtt')}\n")
                f.write(f"<v {speaker}>{text}\n\n")
def split_long_segment(start, end, max_length=MAX_SEGMENT_SEC):
    total_duration = end - start
    if total_duration <= max_length:
        return [(start, end)]
    segments = []
    current = start
    while current < end:
        seg_end = min(current + max_length, end)
        segments.append((current, seg_end))
        current = seg_end
    return segments
# --- CTC inference (Wav2Vec2/MMS) ---
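# Greedy decoding: per-frame argmax over the logits, with CTC collapse handled by processor.decode()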
def transcribe_segment_ctc(waveform, sr, model, processor):
    waveform, sr = ensure_16k(waveform, sr)
    if waveform.shape[1] < 400:  # shorter than one conv receptive field (~25 ms at 16 kHz)
        return ""
    inputs = processor(waveform.squeeze(), sampling_rate=sr, return_tensors="pt", padding=True)
    inputs = {k: v.to(DEVICE) for k, v in inputs.items()}
    with torch.no_grad():
        logits = model(**inputs).logits
    pred_ids = torch.argmax(logits, dim=-1)
    return processor.decode(pred_ids[0].cpu())
# --- Whisper inference (seq2seq) with optional Wylie→Tibetan conversion ---
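# Beam search (num_beams=4) with the decoder prompt forced to Tibetan transcription, capped at 225 tokens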
def transcribe_segment_whisper(waveform, sr, model, processor, wylie_output: bool = False):
    waveform, sr = ensure_16k(waveform, sr)
    if waveform.shape[1] < 400:
        return ""
    inputs = processor(waveform.squeeze(), sampling_rate=sr, return_tensors="pt")
    input_features = inputs["input_features"].to(DEVICE)
    forced_ids = processor.get_decoder_prompt_ids(language="Tibetan", task="transcribe")
    with torch.no_grad():
        pred_ids = model.generate(
            input_features,
            forced_decoder_ids=forced_ids,
            num_beams=4,
            max_length=225
        )
    text = processor.batch_decode(pred_ids, skip_special_tokens=True)[0].strip()
    # If this Whisper model outputs Wylie, convert to Tibetan Unicode
    if wylie_output:
        text = ewts_to_unicode(text)
    return text
# ------------------- Speaker Identification -------------------
def identify_speaker(diarization_df, audio_path, voice_print_embedding, speaker_name, inference, threshold=0.6, n_segments=3):
    """Match diarized speakers against the voice-print embedding.

    For each speaker, embed the n_segments longest turns and average the cosine
    distances; the closest speaker below `threshold` is labelled `speaker_name`.
    """
    waveform, sr = torchaudio.load(audio_path)
    speaker_distances = {}
    for speaker in diarization_df['speaker'].unique():
        sp_df = diarization_df[diarization_df['speaker'] == speaker].copy()
        sp_df['duration'] = sp_df['end'] - sp_df['start']
        sp_df = sp_df.sort_values(by='duration', ascending=False).head(n_segments)
        distances = []
        for _, row in sp_df.iterrows():
            start, end = int(row['start'] * sr), int(row['end'] * sr)
            segment = waveform[:, start:end]
            seg_path = f"/tmp/{speaker}_{start}_{end}.wav"
            torchaudio.save(seg_path, segment, sr)
            try:
                seg_embedding = np.atleast_2d(inference(seg_path))
                dist = cdist(seg_embedding, voice_print_embedding, metric="cosine")[0, 0]
                distances.append(dist)
            except Exception as e:
                print(f"[WARN] Error embedding segment {speaker} {row['start']}-{row['end']}: {e}")
        if distances:
            speaker_distances[speaker] = np.mean(distances)
    if not speaker_distances:
        return None, {}, diarization_df
    best_match = min(speaker_distances, key=speaker_distances.get)
    min_distance = speaker_distances[best_match]
    if min_distance <= threshold:
        mapping = {sp: speaker_name if sp == best_match else f"Other Speaker {i}"
                   for i, sp in enumerate(speaker_distances.keys())}
    else:
        mapping = {sp: f"Speaker {i}" for i, sp in enumerate(speaker_distances.keys())}
    diarization_df['identified_speaker'] = diarization_df['speaker'].map(mapping)
    return best_match, mapping, diarization_df
# ------------------- Main -------------------
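# Flow: load audio -> diarize -> match voice print -> chunk long turns -> per-segment ASR -> SRT/VTT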
def process_audio(model_choice, mode, voice_print_path, audio_path, speaker_name, threshold=0.6):
    """Diarize, identify the target speaker, transcribe each segment, and return transcript + subtitle paths."""
    # Load full audio
    waveform, sample_rate = torchaudio.load(audio_path)
    waveform, sample_rate = ensure_16k(waveform, sample_rate)
    # Load selected ASR (CTC or Whisper) + meta flags
    model_type, asr_model, asr_processor, meta = load_asr_model(model_choice)
    wylie_output = bool(meta.get("wylie_output", False))
    # Voice print
    vp_waveform, vp_sr = torchaudio.load(voice_print_path)
    vp_waveform, vp_sr = ensure_16k(vp_waveform, vp_sr)
    tmp_vp = "/tmp/voice_print_16k.wav"
    torchaudio.save(tmp_vp, vp_waveform, vp_sr)
    voice_print_embedding = np.atleast_2d(embedding_inference(tmp_vp))
    results = []
    if "Diarization" in mode:
        if diarization_pipeline is None:
            return "Pyannote diarization is not available.", None, None
        diarization = diarization_pipeline({"waveform": waveform, "sample_rate": sample_rate})
        # Normalize the diarization output across pyannote.audio versions
        data = []
        if hasattr(diarization, 'speaker_diarization'):
            # Newer releases wrap the annotation in an output object
            for turn, _, speaker in diarization.speaker_diarization.itertracks(yield_label=True):
                data.append({"start": turn.start, "end": turn.end, "speaker": speaker})
        elif hasattr(diarization, 'itertracks'):
            # Classic API: the pipeline returns an Annotation directly
            for segment, _, speaker in diarization.itertracks(yield_label=True):
                data.append({"start": segment.start, "end": segment.end, "speaker": speaker})
        else:
            return "Unsupported pyannote.audio version. Please check the diarization output format.", None, None
        if not data:
            return "No speaker segments found in diarization.", None, None
        diarization_df = pd.DataFrame(data)
        # Identify target speaker
        _, mapping, diarization_df = identify_speaker(
            diarization_df, audio_path, voice_print_embedding, speaker_name, embedding_inference, threshold
        )
        for _, row in diarization_df.iterrows():
            for seg_start, seg_end in split_long_segment(row['start'], row['end']):
                seg_waveform = waveform[:, int(seg_start * sample_rate):int(seg_end * sample_rate)]
                if row['identified_speaker'] == speaker_name:
                    # Target speaker -> use selected ASR path
                    if model_type == "ctc":
                        transcription = transcribe_segment_ctc(seg_waveform, sample_rate, asr_model, asr_processor)
                    else:  # whisper
                        transcription = transcribe_segment_whisper(
                            seg_waveform, sample_rate, asr_model, asr_processor, wylie_output=wylie_output
                        )
                elif mode == "Diarization (Target Speaker Only)":
                    transcription = ""  # skip other speakers
                else:
                    # Other speakers -> Whisper Large v3 fallback (already Tibetan)
                    transcription = transcribe_with_whisper_large(seg_waveform, sample_rate)
                results.append((seg_start, seg_end, transcription, row['identified_speaker']))
    # Save subtitle files
    base_path = os.path.splitext(audio_path)[0]
    srt_path = f"{base_path}_identified.srt"
    vtt_path = f"{base_path}_identified.vtt"
    create_subtitle_file(results, srt_path, "srt")
    create_subtitle_file(results, vtt_path, "vtt")
    transcript_text = "\n".join(f"{sp}: {txt}" for (_, _, txt, sp) in results)
    return transcript_text, srt_path, vtt_path
# ------------------- Gradio -------------------
demo = gr.Interface(
    fn=process_audio,
    inputs=[
        gr.Dropdown(
            choices=list(MODEL_OPTIONS.keys()),
            value="base (MMS Wav2Vec2)",
            label="Select ASR Model"
        ),
        gr.Radio(
            choices=["Diarization (Transcribe All)", "Diarization (Target Speaker Only)"],
            value="Diarization (Transcribe All)",
            label="Segmentation Method"
        ),
        gr.Audio(sources=["upload"], type="filepath", label="Voice Print Audio"),
        gr.Audio(sources=["upload"], type="filepath", label="Full Audio"),
        gr.Textbox(value="GR", label="Speaker Name for Voice Print"),
    ],
    outputs=[
        gr.Textbox(
            label="Transcript",
            lines=24,       # visible height (try 20-30)
            max_lines=60,   # optional scroll cap
            show_copy_button=True
        ),
        gr.File(label="SRT File"),
        gr.File(label="WebVTT File"),
    ],
    title="STT + Speaker Identification",
    description=(
        "Choose an ASR model (MMS Wav2Vec2 or your fine-tuned Whisper). "
        "If you pick 'Whisper (Wylie, default tokenizer)', the output will be converted from Wylie to Tibetan Unicode via pyewts. "
        "Target speaker → chosen model; other speakers → Whisper Large v3."
    )
)
if __name__ == "__main__":
    demo.launch(share=True)