|
|
import io |
|
|
import datetime |
|
|
import tempfile |
|
|
import gradio as gr |
|
|
import spaces |
|
|
import torch |
|
|
import traceback |
|
|
from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM |
|
|
|
|
|
|
|
|
# Hugging Face model id used for the speech-recognition pipeline.
ASR_MODEL = "kotoba-tech/kotoba-whisper-v1.0"

# Hugging Face model id used for translation (FP8 checkpoint).
# The non-FP8 checkpoint, described as the stable version, is "tencent/Hunyuan-MT-7B".
TRANSLATION_MODEL = "tencent/Hunyuan-MT-7B-fp8"
|
|
|
|
|
|
|
|
# Speech-recognition pipeline. Decoding is forced to Japanese transcription
# through generate_kwargs.
# A transformers pipeline built without `device` stays on CPU, so place it on
# CUDA whenever a GPU is visible and fall back to CPU otherwise.
pipe_asr = pipeline(
    "automatic-speech-recognition",
    model=ASR_MODEL,
    device="cuda" if torch.cuda.is_available() else "cpu",
    generate_kwargs={"language": "japanese", "task": "transcribe"},
)
|
|
|
|
|
|
|
|
# Load the translation tokenizer and model once, at import time.
print("[INFO] Loading translation model...")

tokenizer_tr = AutoTokenizer.from_pretrained(TRANSLATION_MODEL)

model_tr = AutoModelForCausalLM.from_pretrained(
    TRANSLATION_MODEL,
    trust_remote_code=True,  # allow custom modelling code shipped with the checkpoint, if any
    device_map="auto",       # let transformers/accelerate choose device placement
)

# Ask transformers to (re-)tie shared weights per the model config, then
# switch to inference mode (eval() returns the same module object).
model_tr.tie_weights()
model_tr = model_tr.eval()

print("[INFO] Translation model ready (weights tied, eval mode).")
|
|
|
|
|
|
|
|
def translate_text(text: str, target_language: str = "English") -> str:
    """Translate *text* into *target_language* with the Hunyuan MT model.

    The prompt is fed as plain text. Only the tokens generated *after* the
    prompt are decoded, so the prompt is never echoed back. Multi-line input
    (several subtitle lines joined with "\\n") therefore comes back as
    multi-line output instead of being cut down to its last line.

    Args:
        text: Source text; may span several lines.
        target_language: Language name inserted into the instruction prompt.

    Returns:
        The stripped translation. Returns "" for empty or whitespace-only input
        without calling the model. If anything in the generation path raises,
        returns a "(⚠️ Translation error: ...)" marker string; this function
        does not propagate exceptions.
    """
    print(f"[DEBUG] Starting translation to {target_language}")

    if not text or not text.strip():
        return ""

    try:
        prompt = f"Translate the following segment into {target_language}, without additional explanation.\n{text}"

        inputs = tokenizer_tr(prompt, return_tensors="pt")
        # Decoder-only models generally reject token_type_ids in generate().
        inputs.pop("token_type_ids", None)
        inputs = inputs.to(model_tr.device, non_blocking=True)

        # Remember the prompt length so the echoed prompt can be sliced off the output.
        prompt_len = inputs["input_ids"].shape[-1]
        on_cuda = model_tr.device.type == "cuda"

        print("[DEBUG] Generating translation...")

        # torch.autocast replaces the deprecated torch.cuda.amp.autocast.
        # It is only enabled when the model actually sits on a CUDA device.
        with torch.no_grad(), torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=on_cuda):
            outputs = model_tr.generate(
                **inputs,
                max_new_tokens=256,
                temperature=0.7,
                top_p=0.6,
                top_k=20,
                repetition_penalty=1.05,
            )

        if on_cuda:
            # Release cached GPU memory between calls.
            torch.cuda.synchronize()
            torch.cuda.empty_cache()

        # generate() returns prompt + completion; decode only the completion.
        translated = tokenizer_tr.decode(outputs[0][prompt_len:], skip_special_tokens=True).strip()

        print(f"[DEBUG] Translation OK: {translated[:100]}...")
        return translated

    except Exception as e:
        print("[ERROR] Translation failed!")
        traceback.print_exc()
        return f"(⚠️ Translation error: {e})"
|
|
|
|
|
|
|
|
def _srt_timestamp(seconds):
    """Format a time offset in seconds as an SRT timestamp ``HH:MM:SS,mmm``.

    Negative values are clamped to zero; milliseconds are rounded.
    """
    total_ms = max(0, round(float(seconds) * 1000))
    hours, rem = divmod(total_ms, 3_600_000)
    minutes, rem = divmod(rem, 60_000)
    secs, millis = divmod(rem, 1_000)
    return f"{hours:02d}:{minutes:02d}:{secs:02d},{millis:03d}"


def _chunk_span(chunk):
    """Return ``(start, end)`` in seconds for one ASR chunk, or ``None``.

    Accepts the transformers ASR pipeline shape ``{"timestamp": (start, end)}``
    and also explicit ``"start"``/``"end"`` keys. Returns ``None`` when either
    bound is missing; the pipeline may report an open-ended final segment with
    ``end=None``.
    """
    timestamp = chunk.get("timestamp")
    if timestamp is not None:
        start, end = timestamp
    else:
        start, end = chunk.get("start"), chunk.get("end")
    if start is None or end is None:
        return None
    return float(start), float(end)


def _chunk_text(chunk):
    """Return a chunk's text stripped and flattened onto one line.

    Internal line breaks become spaces, so that each chunk maps to exactly one
    line of a batched translation prompt and never puts a blank line (the SRT
    cue separator) inside a cue.
    """
    raw = chunk.get("text") or ""
    return " ".join(part.strip() for part in raw.splitlines() if part.strip())


def _translate_in_batches(texts, target_language, batch_size):
    """Translate *texts* ``batch_size`` lines per model call.

    The result is index-aligned with *texts*:
    - short replies are padded with "(translation missing)";
    - surplus reply lines are dropped;
    - a batch whose call raises is filled with "(translation error)".
    """
    translations = []
    total_batches = (len(texts) + batch_size - 1) // batch_size
    for batch_no, offset in enumerate(range(0, len(texts), batch_size), start=1):
        batch = texts[offset:offset + batch_size]
        print(f"[DEBUG] Translating batch {batch_no}/{total_batches} ({len(batch)} lines)")
        try:
            translated_block = translate_text("\n".join(batch), target_language)
        except Exception as e:
            print(f"[ERROR] Batch translation failed at {offset}: {e}")
            translations.extend(["(translation error)"] * len(batch))
            continue
        batch_lines = [line.strip() for line in translated_block.split("\n") if line.strip()]
        # Pad/truncate so every source line keeps exactly one translation.
        batch_lines += ["(translation missing)"] * (len(batch) - len(batch_lines))
        translations.extend(batch_lines[:len(batch)])
    return translations


def format_srt(chunks, translate=False, target_language="English", batch_size=5):
    """Render ASR chunks as an SRT subtitle document.

    Args:
        chunks: ASR segments. Each is a dict with ``"text"`` and either a
            ``"timestamp": (start, end)`` tuple (transformers ASR pipeline
            output) or ``"start"``/``"end"`` keys, in seconds.
        translate: If true, a translated line is placed under each original line.
        target_language: Language name passed to ``translate_text``.
        batch_size: Subtitle lines sent per translation call; values below 1
            are treated as 1.

    Returns:
        The SRT text, or "" when no chunk is usable. Cues are numbered
        consecutively from 1, and timestamps use the SRT ``HH:MM:SS,mmm``
        format. A chunk is skipped if it lacks a start or end timestamp, or if
        it has no text.
    """
    print(f"[INFO] Formatting {len(chunks)} subtitle chunks (translate={translate})")

    # Collect only chunks that can become a valid cue. Translations are
    # computed for exactly this list, so they can never drift out of alignment.
    cues = []
    for chunk in chunks:
        span = _chunk_span(chunk)
        text = _chunk_text(chunk)
        if span is None or not text:
            continue
        cues.append((span[0], span[1], text))

    translations = []
    if translate:
        # === BATCHED TRANSLATION ===
        print(f"[DEBUG] Starting batched translation ({len(cues)} total lines)")
        translations = _translate_in_batches(
            [text for _, _, text in cues], target_language, max(1, int(batch_size))
        )

    srt_lines = []
    for number, (start, end, text) in enumerate(cues, start=1):
        if translate:
            text = f"{text}\n{translations[number - 1]}"
        srt_lines.extend([
            str(number),
            f"{_srt_timestamp(start)} --> {_srt_timestamp(end)}",
            text,
            "",
        ])

    if translate:
        print(f"[INFO] Completed batched translation in {len(translations)} lines.")
    return "\n".join(srt_lines)
|
|
|
|
|
|
|
|
|
|
|
@spaces.GPU
def process_audio(audio_file, translate, target_lang):
    """Transcribe an uploaded audio file and build a downloadable SRT file.

    Args:
        audio_file: Path of the uploaded audio, or ``None`` when nothing was
            uploaded.
        translate: If true, each subtitle cue also gets a translated line.
        target_lang: Target language name forwarded to ``format_srt``.

    Returns:
        ``(full_text, srt_path)``: the plain transcript and the path of a
        UTF-8 ``.srt`` temporary file. The file is not deleted here.

    Raises:
        gr.Error: if no audio file was provided, so the UI shows a readable
            message instead of a stack trace from the ASR pipeline.
    """
    if not audio_file:
        raise gr.Error("Please upload an audio file first.")

    print("[INFO] Starting ASR transcription...")
    result = pipe_asr(audio_file, return_timestamps=True)
    full_text = result.get("text", "")
    chunks = result.get("chunks", []) or result.get("segments", [])
    print(f"[INFO] Transcription complete. {len(chunks)} segments detected.")

    print("[INFO] Generating SRT output...")
    srt_content = format_srt(chunks, translate, target_lang)

    # delete=False: the file must outlive this function so it can be downloaded.
    with tempfile.NamedTemporaryFile(delete=False, suffix=".srt") as tmp:
        tmp.write(srt_content.encode("utf-8"))
        tmp_path = tmp.name

    print("[INFO] SRT file ready.")
    return full_text, tmp_path
|
|
|
|
|
|
|
|
# Gradio UI: upload -> transcribe -> optional translation -> SRT download.
with gr.Blocks() as demo:
    gr.Markdown("## 🎙️ Whisper Transcription + Hunyuan Translation\nUpload → Transcribe → (optional) Translate → Download SRT")

    with gr.Row():
        audio_input = gr.Audio(sources=["upload"], type="filepath", label="Upload audio (.wav / .mp3)")
        transcribe_button = gr.Button("Transcribe / Translate", variant="primary")

    # The target language is chosen in the dropdown below, so the checkbox
    # label must not promise a fixed language.
    translate_checkbox = gr.Checkbox(
        label="Translate to the selected target language (adds translation under original)",
        value=False,
    )
    target_dropdown = gr.Dropdown(
        choices=["English", "Japanese", "Chinese", "French", "German"],
        label="Target language",
        value="English",
    )
    transcript_box = gr.Textbox(label="Transcript / Translation", lines=10)
    download_button = gr.File(label="Download .srt")

    # The two outputs receive the (transcript, srt_path) pair returned by process_audio.
    transcribe_button.click(
        process_audio,
        inputs=[audio_input, translate_checkbox, target_dropdown],
        outputs=[transcript_box, download_button],
    )

demo.launch()