import datetime
import tempfile
import traceback

import gradio as gr
import spaces
import torch
from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM

# === MODEL CONFIG ===
ASR_MODEL = "kotoba-tech/kotoba-whisper-v1.0"
TRANSLATION_MODEL = "tencent/Hunyuan-MT-7B-fp8"  # non-fp8 stable variant: tencent/Hunyuan-MT-7B

# === LOAD ASR MODEL ===
pipe_asr = pipeline(
    "automatic-speech-recognition",
    model=ASR_MODEL,
    generate_kwargs={"language": "japanese", "task": "transcribe"},
)

# === LOAD TRANSLATION MODEL ===
print("[INFO] Loading translation model...")
tokenizer_tr = AutoTokenizer.from_pretrained(TRANSLATION_MODEL)
model_tr = AutoModelForCausalLM.from_pretrained(
    TRANSLATION_MODEL,
    device_map="auto",
    trust_remote_code=True,
)
model_tr.tie_weights()
model_tr.eval()
print("[INFO] Translation model ready (weights tied, eval mode).")


# === TIMESTAMP HELPERS ===
def srt_timestamp(seconds: float) -> str:
    """Format seconds as an SRT timestamp (HH:MM:SS,mmm).

    str(datetime.timedelta) yields "H:MM:SS.ffffff", which most players
    reject, so build the SRT form explicitly.
    """
    ms = int(round(seconds * 1000))
    h, rem = divmod(ms, 3_600_000)
    m, rem = divmod(rem, 60_000)
    s, ms = divmod(rem, 1000)
    return f"{h:02d}:{m:02d}:{s:02d},{ms:03d}"


def chunk_times(chunk):
    """Return (start, end) for a segment.

    The transformers ASR pipeline emits {"timestamp": (start, end)} chunks,
    while segment-style outputs use separate "start"/"end" keys; accept both.
    """
    ts = chunk.get("timestamp")
    if ts is not None:
        return ts[0], ts[1]
    return chunk.get("start"), chunk.get("end")


# === TRANSLATION FUNCTION ===
def translate_text(text: str, target_language: str = "English") -> str:
    print(f"[DEBUG] Starting translation to {target_language}")
    try:
        prompt = (
            f"Translate the following segment into {target_language}, "
            f"without additional explanation.\n{text}"
        )
        inputs = tokenizer_tr(prompt, return_tensors="pt")
        if "token_type_ids" in inputs:
            del inputs["token_type_ids"]
        inputs = inputs.to(model_tr.device)
        print("[DEBUG] Generating translation...")
        device_type = "cuda" if torch.cuda.is_available() else "cpu"
        with torch.no_grad(), torch.autocast(device_type=device_type, dtype=torch.bfloat16):
            outputs = model_tr.generate(
                **inputs,
                max_new_tokens=256,
                do_sample=True,  # required; otherwise the sampling params below are ignored
                temperature=0.7,
                top_p=0.6,
                top_k=20,
                repetition_penalty=1.05,
            )
        if torch.cuda.is_available():
            torch.cuda.synchronize()
            torch.cuda.empty_cache()
        # Decode only the newly generated tokens so the prompt is never echoed back.
        prompt_len = inputs["input_ids"].shape[1]
        translated = tokenizer_tr.decode(outputs[0][prompt_len:], skip_special_tokens=True).strip()
        print(f"[DEBUG] Translation OK: {translated[:100]}...")
        return translated
    except Exception as e:
        print("[ERROR] Translation failed!")
        traceback.print_exc()
        return f"(⚠️ Translation error: {e})"


# === FORMAT SRT (with optional translation) ===
def format_srt(chunks, translate=False, target_language="English", batch_size=5):
    srt_lines = []
    print(f"[INFO] Formatting {len(chunks)} subtitle chunks (translate={translate})")

    if not translate:
        # Simpler non-translate version: original text only.
        for idx, chunk in enumerate(chunks, start=1):
            start, end = chunk_times(chunk)
            if start is None or end is None:
                continue
            orig_text = chunk.get("text", "").strip()
            srt_lines.append(str(idx))
            srt_lines.append(f"{srt_timestamp(float(start))} --> {srt_timestamp(float(end))}")
            srt_lines.append(orig_text)
            srt_lines.append("")
        return "\n".join(srt_lines)

    # === BATCHED TRANSLATION ===
    # Keep one entry per chunk (even empty text) so the zip() alignment below holds.
    texts = [c.get("text", "").strip() for c in chunks]
    translations = []
    print(f"[DEBUG] Starting batched translation ({len(texts)} total lines)")
    num_batches = (len(texts) + batch_size - 1) // batch_size
    for i in range(0, len(texts), batch_size):
        batch = texts[i:i + batch_size]
        joined_text = "\n".join(batch)
        print(f"[DEBUG] Translating batch {i // batch_size + 1}/{num_batches} ({len(batch)} lines)")
        try:
            translated_block = translate_text(joined_text, target_language)
            batch_lines = [l.strip() for l in translated_block.split("\n") if l.strip()]
            if len(batch_lines) < len(batch):
                # Pad to keep alignment.
                batch_lines += ["(translation missing)"] * (len(batch) - len(batch_lines))
            translations.extend(batch_lines[:len(batch)])
        except Exception as e:
            print(f"[ERROR] Batch translation failed at {i}: {e}")
            translations.extend(["(translation error)"] * len(batch))
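    # NOTE: the batched path assumes the model returns exactly one output line
    # per input line. A minimal sanity check under that assumption: warn when
    # the counts drift so a misaligned pairing is visible in the logs instead
    # of silently producing shifted subtitles.
    if len(translations) != len(texts):
        print(f"[WARN] {len(texts)} source lines vs {len(translations)} "
              "translated lines; subtitle pairing below may be misaligned.")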
error)"] * len(batch)) # Align translated lines with chunks for idx, (chunk, trans_text) in enumerate(zip(chunks, translations), start=1): start, end = chunk.get("start"), chunk.get("end") if start is None or end is None: continue start_ts = datetime.timedelta(seconds=float(start)) end_ts = datetime.timedelta(seconds=float(end)) orig_text = chunk.get("text", "").strip() line_text = f"{orig_text}\n{trans_text}" srt_lines.append(f"{idx}") srt_lines.append(f"{start_ts} --> {end_ts}") srt_lines.append(line_text) srt_lines.append("") print(f"[INFO] Completed batched translation in {len(translations)} lines.") return "\n".join(srt_lines) # === MAIN PIPELINE FUNCTION === @spaces.GPU def process_audio(audio_file, translate, target_lang): print("[INFO] Starting ASR transcription...") result = pipe_asr(audio_file, return_timestamps=True) full_text = result.get("text", "") chunks = result.get("chunks", []) or result.get("segments", []) print(f"[INFO] Transcription complete. {len(chunks)} segments detected.") print("[INFO] Generating SRT output...") srt_content = format_srt(chunks, translate, target_lang) with tempfile.NamedTemporaryFile(delete=False, suffix=".srt") as tmp: tmp.write(srt_content.encode("utf-8")) tmp_path = tmp.name print("[INFO] SRT file ready.") return full_text, tmp_path # === GRADIO UI === with gr.Blocks() as demo: gr.Markdown("## 🎙️ Whisper Transcription + Hunyuan Translation\nUpload → Transcribe → (optional) Translate → Download SRT") with gr.Row(): audio_input = gr.Audio(sources=["upload"], type="filepath", label="Upload audio (.wav / .mp3)") transcribe_button = gr.Button("Transcribe / Translate", variant="primary") translate_checkbox = gr.Checkbox(label="Translate to English (adds translation under original)", value=False) target_dropdown = gr.Dropdown( choices=["English", "Japanese", "Chinese", "French", "German"], label="Target language", value="English" ) transcript_box = gr.Textbox(label="Transcript / Translation", lines=10) download_button = gr.File(label="Download .srt") transcribe_button.click( process_audio, inputs=[audio_input, translate_checkbox, target_dropdown], outputs=[transcript_box, download_button] ) demo.launch()