|
|
import io |
|
|
import datetime |
|
|
import tempfile |
|
|
import gradio as gr |
|
|
import spaces |
|
|
from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM |
|
|
|
|
|
|
|
|
# Model identifiers: Japanese ASR (distilled Whisper) and a 7B translation LLM.
TRANSCRIBE_MODEL = "kotoba-tech/kotoba-whisper-v1.0"

TRANSLATE_MODEL = "tencent/Hunyuan-MT-7B"


# Speech-to-text pipeline, pinned to Japanese transcription (not translation)
# so Whisper does not auto-detect the language per chunk.
transcriber = pipeline(
    "automatic-speech-recognition",
    model=TRANSCRIBE_MODEL,
    generate_kwargs={"language": "japanese", "task": "transcribe"}
)


# Translation model + tokenizer, loaded once at import time.
# device_map="auto" lets accelerate place the weights (GPU when available).
tokenizer_t = AutoTokenizer.from_pretrained(TRANSLATE_MODEL)

model_t = AutoModelForCausalLM.from_pretrained(TRANSLATE_MODEL, device_map="auto")
|
|
|
|
|
def format_srt(chunks, texts):
    """Build an SRT document from timestamped chunks and corresponding texts.

    Args:
        chunks: List of dicts, each carrying timing either as explicit
            ``"start"``/``"end"`` keys or as a whisper-style
            ``"timestamp": (start, end)`` tuple.
        texts: Subtitle text for each chunk, parallel to ``chunks``.

    Returns:
        The SRT file content as a single string. Chunks without usable
        timing information are skipped.
    """

    def _endpoint(chunk, key, index):
        # Prefer the explicit key, falling back to the "timestamp" tuple.
        # An explicit `is None` test is required here: the old truthiness
        # check (`chunk.get(key) or ...`) silently discarded a perfectly
        # valid start time of 0.0.
        value = chunk.get(key)
        if value is None and "timestamp" in chunk:
            value = chunk["timestamp"][index]
        return value

    def _fmt(seconds):
        # Render an SRT timestamp: HH:MM:SS,mmm
        whole = int(seconds)
        hours, rem = divmod(whole, 3600)
        minutes, secs = divmod(rem, 60)
        millis = int((seconds - whole) * 1000)
        return f"{hours:02d}:{minutes:02d}:{secs:02d},{millis:03d}"

    srt_lines = []
    cue_number = 0  # SRT cue numbers must stay sequential even when chunks are skipped.
    for chunk, text in zip(chunks, texts):
        start = _endpoint(chunk, "start", 0)
        end = _endpoint(chunk, "end", 1)
        if start is None or end is None:
            continue
        cue_number += 1
        srt_lines.append(f"{cue_number}")
        srt_lines.append(f"{_fmt(float(start))} --> {_fmt(float(end))}")
        srt_lines.append(text.strip())
        srt_lines.append("")
    return "\n".join(srt_lines)
|
|
|
|
|
def translate_text(text, target_lang="English"):
    """Translate *text* into *target_lang* using Hunyuan-MT-7B.

    Args:
        text: Source text (Japanese transcript segment).
        target_lang: Target language name injected into the prompt.
            (Bug fix: this parameter was previously ignored and "English"
            was hard-coded.)

    Returns:
        The translated text, stripped of surrounding whitespace.
    """
    prompt = f"Translate the following segment into {target_lang}, without explanation:\n\n{text}"

    inputs = tokenizer_t(prompt, return_tensors="pt")
    # Some tokenizers emit token_type_ids, which generate() does not accept.
    inputs.pop("token_type_ids", None)

    inputs = {k: v.to(model_t.device) for k, v in inputs.items()}

    outputs = model_t.generate(
        **inputs,
        max_new_tokens=512
    )

    # Decode only the newly generated tokens. Decoding outputs[0] in full
    # would echo the prompt back into the "translation", since a causal LM's
    # generate() output begins with the input sequence.
    prompt_len = inputs["input_ids"].shape[1]
    result = tokenizer_t.decode(outputs[0][prompt_len:], skip_special_tokens=True)
    return result.strip()
|
|
|
|
|
|
|
|
@spaces.GPU
def process_audio(audio_file, translate: bool):
    """Transcribe audio, optionally translate it, and return subtitle paths.

    Returns a (transcript, original_srt_path, translated_srt_path) triple;
    the last element is None when translation is disabled. On any failure
    the error is logged and ("Error: ...", None, None) is returned so the
    UI degrades gracefully instead of crashing the GPU worker.
    """
    import sys
    import traceback

    def _write_srt(content):
        # Persist SRT content to a throwaway file and hand back its path.
        with tempfile.NamedTemporaryFile(delete=False, suffix=".srt") as handle:
            handle.write(content.encode("utf-8"))
            return handle.name

    try:
        result = transcriber(audio_file, return_timestamps=True)
        transcript = result.get("text", "")
        # Pipeline versions differ on the key name for timestamped segments.
        segments = result.get("chunks", []) or result.get("segments", [])
        source_texts = [seg.get("text", "") for seg in segments]

        original_path = _write_srt(format_srt(segments, source_texts))

        translated_path = None
        if translate:
            english_texts = [translate_text(seg_text) for seg_text in source_texts]
            translated_path = _write_srt(format_srt(segments, english_texts))

        return transcript, original_path, translated_path

    except Exception as exc:
        # Boundary handler: surface the error to the UI rather than crash.
        print("🚨 GPU worker exception:", exc, file=sys.stderr)
        traceback.print_exc()
        return f"Error: {exc}", None, None
|
|
|
|
|
|
|
|
# --- Gradio UI -----------------------------------------------------------
with gr.Blocks() as demo:
    gr.Markdown("## 🎙️ Audio → Text + (Optional) English SRT Translation")

    # Inputs
    audio_input = gr.Audio(sources=["upload"], type="filepath", label="Upload audio (.wav or .mp3)")
    translate_checkbox = gr.Checkbox(label="Translate to English subtitles", value=True)
    process_button = gr.Button("Process Audio", variant="primary")

    # Outputs
    transcript_box = gr.Textbox(label="Transcript (Original)")
    download_orig = gr.File(label="Download original .srt file")
    download_trans = gr.File(label="Download English .srt file")

    def wrapper(audio_path, translate_flag):
        # Forward to the GPU worker; hide the translated file unless the
        # checkbox was ticked and a path actually came back.
        transcript, original_srt, translated_srt = process_audio(audio_path, translate_flag)
        if not (translate_flag and translated_srt):
            translated_srt = None
        return transcript, original_srt, translated_srt

    process_button.click(
        wrapper,
        inputs=[audio_input, translate_checkbox],
        outputs=[transcript_box, download_orig, download_trans],
    )

demo.launch()