import io
import datetime
import sys
import tempfile
import traceback

import gradio as gr
import spaces
from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM

# Models
TRANSCRIBE_MODEL = "kotoba-tech/kotoba-whisper-v1.0"
TRANSLATE_MODEL = "tencent/Hunyuan-MT-7B"

# Load pipelines once at startup (module-level, as is conventional for a Space).
transcriber = pipeline(
    "automatic-speech-recognition",
    model=TRANSCRIBE_MODEL,
    generate_kwargs={"language": "japanese", "task": "transcribe"},
)
tokenizer_t = AutoTokenizer.from_pretrained(TRANSLATE_MODEL)
model_t = AutoModelForCausalLM.from_pretrained(TRANSLATE_MODEL, device_map="auto")


def _fmt_srt_timestamp(seconds: float) -> str:
    """Format a duration in seconds as an SRT timestamp ``HH:MM:SS,mmm``."""
    ts = datetime.timedelta(seconds=seconds)
    total_seconds = int(ts.total_seconds())
    hours, rem = divmod(total_seconds, 3600)
    minutes, secs = divmod(rem, 60)
    milliseconds = int((ts.total_seconds() - total_seconds) * 1000)
    return f"{hours:02d}:{minutes:02d}:{secs:02d},{milliseconds:03d}"


def _chunk_bounds(chunk: dict):
    """Return ``(start, end)`` seconds for one ASR chunk, or Nones if absent.

    Supports both ``{"start": .., "end": ..}`` and the whisper-pipeline
    ``{"timestamp": (start, end)}`` layouts. Explicit ``is None`` checks are
    used so a legitimate start time of 0.0 is not treated as missing.
    """
    start = chunk.get("start")
    end = chunk.get("end")
    if (start is None or end is None) and "timestamp" in chunk:
        ts = chunk.get("timestamp") or (None, None)
        if start is None:
            start = ts[0]
        if end is None:
            end = ts[1]  # may be None for a truncated final chunk
    return start, end


def format_srt(chunks, texts):
    """Build SRT content from timestamped chunks and corresponding texts.

    Chunks without usable timestamps are skipped; cue numbers stay
    sequential (1, 2, 3, ...) as the SRT format expects, even when
    some chunks are dropped.
    """
    srt_lines = []
    cue_no = 0
    for chunk, text in zip(chunks, texts):
        start, end = _chunk_bounds(chunk)
        if start is None or end is None:
            continue
        cue_no += 1
        srt_lines.append(f"{cue_no}")
        srt_lines.append(
            f"{_fmt_srt_timestamp(float(start))} --> {_fmt_srt_timestamp(float(end))}"
        )
        srt_lines.append(text.strip())
        srt_lines.append("")
    return "\n".join(srt_lines)


def translate_text(text, target_lang="English"):
    """Translate *text* into *target_lang* using Hunyuan-MT-7B.

    Only the newly generated tokens are decoded — the prompt is sliced off
    so the echoed instruction/source text does not leak into the result.
    """
    prompt = f"Translate the following segment into {target_lang}, without explanation:\n\n{text}"

    # Tokenize; drop token_type_ids, which generate() does not accept.
    inputs = tokenizer_t(prompt, return_tensors="pt")
    inputs.pop("token_type_ids", None)

    # Move input tensors to the same device as the model.
    inputs = {k: v.to(model_t.device) for k, v in inputs.items()}

    outputs = model_t.generate(**inputs, max_new_tokens=512)

    # Decode only the continuation, not the echoed prompt.
    prompt_len = inputs["input_ids"].shape[1]
    result = tokenizer_t.decode(outputs[0][prompt_len:], skip_special_tokens=True)
    return result.strip()


def _write_srt(content: str) -> str:
    """Persist SRT *content* to a temp ``.srt`` file and return its path."""
    with tempfile.NamedTemporaryFile(delete=False, suffix=".srt") as tmp:
        tmp.write(content.encode("utf-8"))
        return tmp.name


@spaces.GPU
def process_audio(audio_file, translate: bool):
    """Transcribe, optionally translate, and return subtitle paths.

    Returns ``(full_text, orig_srt_path, trans_srt_path)``; the translated
    path is ``None`` when *translate* is False. Exceptions are caught and
    reported as a placeholder tuple so Gradio does not crash.
    """
    try:
        # --- Transcription ---
        res = transcriber(audio_file, return_timestamps=True)
        full_text = res.get("text", "")
        chunks = res.get("chunks") or res.get("segments") or []
        orig_texts = [c.get("text", "") for c in chunks]

        # --- Original subtitles ---
        orig_srt_path = _write_srt(format_srt(chunks, orig_texts))

        # --- Optional translation ---
        trans_srt_path = None
        if translate:
            translated_texts = [translate_text(txt) for txt in orig_texts]
            trans_srt_path = _write_srt(format_srt(chunks, translated_texts))

        return full_text, orig_srt_path, trans_srt_path
    except Exception as e:
        print("🚨 GPU worker exception:", e, file=sys.stderr)
        traceback.print_exc()
        # Return placeholders so Gradio doesn’t crash
        return f"Error: {e}", None, None


with gr.Blocks() as demo:
    gr.Markdown("## 🎙️ Audio → Text + (Optional) English SRT Translation")

    audio_input = gr.Audio(sources=["upload"], type="filepath", label="Upload audio (.wav or .mp3)")
    translate_checkbox = gr.Checkbox(label="Translate to English subtitles", value=True)
    process_button = gr.Button("Process Audio", variant="primary")

    transcript_box = gr.Textbox(label="Transcript (Original)")
    download_orig = gr.File(label="Download original .srt file")
    download_trans = gr.File(label="Download English .srt file")

    # process_audio already returns None for the translated path when the
    # checkbox is off, so no wrapper indirection is needed.
    process_button.click(
        process_audio,
        inputs=[audio_input, translate_checkbox],
        outputs=[transcript_box, download_orig, download_trans],
    )

demo.launch()