MegaTronX's picture
Rename app.py to app.py.bak2
e5e7349 verified
import io
import datetime
import tempfile
import gradio as gr
import spaces
import torch
import traceback
from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM
# === MODEL CONFIG ===
ASR_MODEL = "kotoba-tech/kotoba-whisper-v1.0"
TRANSLATION_MODEL = "tencent/Hunyuan-MT-7B-fp8" # Non-fp8 stable version tencent/Hunyuan-MT-7B
# === LOAD ASR MODEL ===
pipe_asr = pipeline(
"automatic-speech-recognition",
model=ASR_MODEL,
generate_kwargs={"language": "japanese", "task": "transcribe"}
)
# === LOAD TRANSLATION MODEL ===
print("[INFO] Loading translation model...")
tokenizer_tr = AutoTokenizer.from_pretrained(TRANSLATION_MODEL)
model_tr = AutoModelForCausalLM.from_pretrained(
TRANSLATION_MODEL,
device_map="auto",
trust_remote_code=True
)
model_tr.tie_weights()
model_tr.eval()
print("[INFO] Translation model ready (weights tied, eval mode).")
# === TRANSLATION FUNCTION ===
def translate_text(text: str, target_language: str = "English") -> str:
print(f"[DEBUG] Starting translation to {target_language}")
try:
prompt = f"Translate the following segment into {target_language}, without additional explanation.\n{text}"
inputs = tokenizer_tr(prompt, return_tensors="pt")
if "token_type_ids" in inputs:
del inputs["token_type_ids"]
inputs = inputs.to(model_tr.device, non_blocking=True)
print("[DEBUG] Generating translation...")
with torch.no_grad(), torch.cuda.amp.autocast(dtype=torch.bfloat16):
outputs = model_tr.generate(
**inputs,
max_new_tokens=256,
temperature=0.7,
top_p=0.6,
top_k=20,
repetition_penalty=1.05
)
torch.cuda.synchronize()
torch.cuda.empty_cache()
translated = tokenizer_tr.decode(outputs[0], skip_special_tokens=True).strip()
# Remove prompt echo if present
if translated.lower().startswith("translate"):
translated = translated.split("\n")[-1].strip()
print(f"[DEBUG] Translation OK: {translated[:100]}...")
return translated
except Exception as e:
print("[ERROR] Translation failed!")
traceback.print_exc()
return f"(⚠️ Translation error: {e})"
# === FORMAT SRT (with optional translation) ===
def format_srt(chunks, translate=False, target_language="English", batch_size=5):
srt_lines = []
print(f"[INFO] Formatting {len(chunks)} subtitle chunks (translate={translate})")
if not translate:
# simpler non-translate version
for idx, chunk in enumerate(chunks, start=1):
start, end = chunk.get("start"), chunk.get("end")
if start is None or end is None:
continue
start_ts = datetime.timedelta(seconds=float(start))
end_ts = datetime.timedelta(seconds=float(end))
orig_text = chunk.get("text", "").strip()
srt_lines.append(f"{idx}")
srt_lines.append(f"{start_ts} --> {end_ts}")
srt_lines.append(orig_text)
srt_lines.append("")
return "\n".join(srt_lines)
# === BATCHED TRANSLATION ===
texts = [c.get("text", "").strip() for c in chunks if c.get("text")]
translations = []
print(f"[DEBUG] Starting batched translation ({len(texts)} total lines)")
for i in range(0, len(texts), batch_size):
batch = texts[i:i + batch_size]
joined_text = "\n".join(batch)
print(f"[DEBUG] Translating batch {i//batch_size + 1}/{(len(texts)+batch_size-1)//batch_size} ({len(batch)} lines)")
try:
translated_block = translate_text(joined_text, target_language)
batch_lines = [l.strip() for l in translated_block.split("\n") if l.strip()]
if len(batch_lines) < len(batch):
# pad to keep alignment
batch_lines += ["(translation missing)"] * (len(batch) - len(batch_lines))
translations.extend(batch_lines[:len(batch)])
except Exception as e:
print(f"[ERROR] Batch translation failed at {i}: {e}")
translations.extend(["(translation error)"] * len(batch))
# Align translated lines with chunks
for idx, (chunk, trans_text) in enumerate(zip(chunks, translations), start=1):
start, end = chunk.get("start"), chunk.get("end")
if start is None or end is None:
continue
start_ts = datetime.timedelta(seconds=float(start))
end_ts = datetime.timedelta(seconds=float(end))
orig_text = chunk.get("text", "").strip()
line_text = f"{orig_text}\n{trans_text}"
srt_lines.append(f"{idx}")
srt_lines.append(f"{start_ts} --> {end_ts}")
srt_lines.append(line_text)
srt_lines.append("")
print(f"[INFO] Completed batched translation in {len(translations)} lines.")
return "\n".join(srt_lines)
# === MAIN PIPELINE FUNCTION ===
@spaces.GPU
def process_audio(audio_file, translate, target_lang):
print("[INFO] Starting ASR transcription...")
result = pipe_asr(audio_file, return_timestamps=True)
full_text = result.get("text", "")
chunks = result.get("chunks", []) or result.get("segments", [])
print(f"[INFO] Transcription complete. {len(chunks)} segments detected.")
print("[INFO] Generating SRT output...")
srt_content = format_srt(chunks, translate, target_lang)
with tempfile.NamedTemporaryFile(delete=False, suffix=".srt") as tmp:
tmp.write(srt_content.encode("utf-8"))
tmp_path = tmp.name
print("[INFO] SRT file ready.")
return full_text, tmp_path
# === GRADIO UI ===
with gr.Blocks() as demo:
gr.Markdown("## 🎙️ Whisper Transcription + Hunyuan Translation\nUpload → Transcribe → (optional) Translate → Download SRT")
with gr.Row():
audio_input = gr.Audio(sources=["upload"], type="filepath", label="Upload audio (.wav / .mp3)")
transcribe_button = gr.Button("Transcribe / Translate", variant="primary")
translate_checkbox = gr.Checkbox(label="Translate to English (adds translation under original)", value=False)
target_dropdown = gr.Dropdown(
choices=["English", "Japanese", "Chinese", "French", "German"],
label="Target language",
value="English"
)
transcript_box = gr.Textbox(label="Transcript / Translation", lines=10)
download_button = gr.File(label="Download .srt")
transcribe_button.click(
process_audio,
inputs=[audio_input, translate_checkbox, target_dropdown],
outputs=[transcript_box, download_button]
)
demo.launch()