import datetime
import tempfile

import gradio as gr
import spaces
from transformers import pipeline

# ========================
# CONFIGURATION
# ========================
ASR_MODEL = "kotoba-tech/kotoba-whisper-v1.0"

MODEL_MAP = {
    ("en", "zh"): "Helsinki-NLP/opus-mt-en-zh",
    ("zh", "en"): "Helsinki-NLP/opus-mt-zh-en",
    ("en", "ja"): "Helsinki-NLP/opus-mt-en-ja",
    ("ja", "en"): "Helsinki-NLP/opus-mt-ja-en",
}

translator_cache = {}

# ========================
# MODEL LOADERS
# ========================
# kotoba-whisper is a Japanese-focused Whisper distillation, so generation
# is pinned to Japanese transcription regardless of the UI language choice.
asr = pipeline(
    "automatic-speech-recognition",
    model=ASR_MODEL,
    generate_kwargs={"language": "japanese", "task": "transcribe"},
)


def get_translator(src_lang, tgt_lang):
    """Return a cached translation pipeline for the given language pair."""
    key = (src_lang, tgt_lang)
    if key not in translator_cache:
        if key in MODEL_MAP:
            translator_cache[key] = pipeline("translation", model=MODEL_MAP[key])
        else:
            raise ValueError(f"No model for {src_lang}->{tgt_lang}")
    return translator_cache[key]


# ========================
# UTILITIES
# ========================
def safe_translate(text, src, tgt):
    """Translate text; pivots through English for ja<->zh. Never raises."""
    try:
        if (src, tgt) in MODEL_MAP:
            translator = get_translator(src, tgt)
            return translator(text, max_length=512)[0]["translation_text"]
        elif (src, tgt) in {("ja", "zh"), ("zh", "ja")}:
            # No direct Opus-MT model for this pair: pivot through English.
            mid = safe_translate(text, src, "en")
            return safe_translate(mid, "en", tgt)
        else:
            return text
    except Exception as e:
        return f"[Translation error: {str(e)}]"


def _fmt_srt_timestamp(ts):
    """Format a timedelta as an SRT timestamp (HH:MM:SS,mmm)."""
    total_seconds = int(ts.total_seconds())
    hours = total_seconds // 3600
    minutes = (total_seconds % 3600) // 60
    seconds = total_seconds % 60
    milliseconds = int((ts.total_seconds() - total_seconds) * 1000)
    return f"{hours:02d}:{minutes:02d}:{seconds:02d},{milliseconds:03d}"


def format_srt(chunks, src_lang=None, tgt_lang=None, bilingual=False, do_translate=False):
    """Build SRT, optionally with translation and bilingual text."""
    if not chunks or not isinstance(chunks, list):
        return "No timestamp data available."

    srt_lines = []
    for idx, chunk in enumerate(chunks, start=1):
        start, end = None, None
        if "timestamp" in chunk and isinstance(chunk["timestamp"], (list, tuple)):
            start, end = chunk["timestamp"]
        elif "start" in chunk and "end" in chunk:
            start, end = chunk["start"], chunk["end"]
        if start is None:
            continue
        if end is None:
            # Whisper can leave the final chunk open-ended; fall back to start.
            end = start

        start_ts = datetime.timedelta(seconds=float(start))
        end_ts = datetime.timedelta(seconds=float(end))

        text = chunk.get("text", "").strip()
        if do_translate and tgt_lang and src_lang and src_lang != tgt_lang:
            translated = safe_translate(text, src_lang, tgt_lang)
            text = f"{text}\n{translated}" if bilingual else translated

        srt_lines.append(f"{idx}")
        srt_lines.append(f"{_fmt_srt_timestamp(start_ts)} --> {_fmt_srt_timestamp(end_ts)}")
        srt_lines.append(text)
        srt_lines.append("")

    return "\n".join(srt_lines)


# ========================
# MAIN GPU TASK
# ========================
@spaces.GPU
def transcribe_and_translate(audio_file, src_lang, tgt_lang, bilingual, do_translate):
    """Run ASR and optional translation; return the transcript and an SRT file path."""
    result = asr(audio_file, return_timestamps=True)
    text = result.get("text", "")
    chunks = result.get("chunks", []) or result.get("segments", [])

    srt_content = format_srt(chunks, src_lang, tgt_lang, bilingual, do_translate)

    tmp_srt = tempfile.NamedTemporaryFile(delete=False, suffix=".srt")
    tmp_srt.write(srt_content.encode("utf-8"))
    tmp_srt.close()

    return text, tmp_srt.name


# ========================
# GRADIO UI
# ========================
with gr.Blocks() as demo:
    gr.Markdown("## 🎙️ Transcribe + (Optional) Translate Audio to Subtitles")

    with gr.Row():
        audio_input = gr.Audio(
            sources=["upload"],
            type="filepath",
            label="Upload audio (.wav or .mp3)",
        )
    with gr.Row():
        src_lang = gr.Dropdown(["ja", "en", "zh"], label="Source Language", value="ja")
        tgt_lang = gr.Dropdown(["en", "zh", "ja"], label="Target Language", value="en")

    do_translate = gr.Checkbox(value=False, label="Translate to Target Language")
    bilingual = gr.Checkbox(value=False, label="Include Original Text (Bilingual SRT)")

    transcribe_btn = gr.Button("Transcribe", variant="primary")
    transcript_box = gr.Textbox(label="Transcript (Original Language)")
    srt_output = gr.File(label="Download .srt file")

    # Input order must match the transcribe_and_translate signature.
    transcribe_btn.click(
        transcribe_and_translate,
        inputs=[audio_input, src_lang, tgt_lang, bilingual, do_translate],
        outputs=[transcript_box, srt_output],
    )

demo.launch()
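
# ------------------------------------------------------------------
# Illustrative sanity check (not executed as part of the app): since
# format_srt operates on plain dicts, the SRT logic can be exercised
# without loading any model. The sample chunks below are made-up
# values, shown only to document the expected input/output shape.
#
#   sample_chunks = [
#       {"timestamp": (0.0, 2.5), "text": "こんにちは"},
#       {"timestamp": (2.5, None), "text": "さようなら"},  # open-ended final chunk
#   ]
#   print(format_srt(sample_chunks, src_lang="ja", tgt_lang="en",
#                    bilingual=False, do_translate=False))
#
# Expected output shape:
#   1
#   00:00:00,000 --> 00:00:02,500
#   こんにちは
#
#   2
#   00:00:02,500 --> 00:00:02,500
#   さようなら
# ------------------------------------------------------------------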