mT5-Small — Tamil TTS (Vocab-Pruned, Tying-Fixed)

Vocab-pruned google/mt5-small for Tamil TTS training.

Critical Fix (v2)

Weight tying is now correct. Only shared.weight is stored in the checkpoint. lm_head.weight, encoder.embed_tokens.weight, and decoder.embed_tokens.weight are all tied to shared.weight at load time via tie_word_embeddings=True.

The previous version saved both shared.weight and lm_head.weight, which caused HuggingFace to silently UNTIE them, leading to divergent matrices during training and exploding logits.
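The mechanics of the bug can be sketched in plain PyTorch (illustrative, not the transformers internals): tying means the embedding and the LM head share one tensor, so the checkpoint should contain a single copy of it.

```python
import torch
import torch.nn as nn

vocab, d_model = 50_001, 512
shared = nn.Embedding(vocab, d_model)
lm_head = nn.Linear(d_model, vocab, bias=False)
lm_head.weight = shared.weight  # tie: both modules now hold the SAME tensor

# Tied means one storage, not two equal copies:
assert lm_head.weight.data_ptr() == shared.weight.data_ptr()

# A backward pass through either module accumulates into the single matrix;
# there is no second copy that can drift away during training.
loss = lm_head(shared(torch.tensor([0]))).sum()
loss.backward()
assert shared.weight.grad is not None
```

If a checkpoint stores two separate copies of this matrix, loading it recreates two independent tensors, and the assertion above fails: that is the silent untying described above.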

Why mT5 for TTS?

Text -> audio tokens is a true seq2seq problem:

  • Encoder reads Tamil text (bidirectional, deep understanding)
  • Decoder generates FSQ audio tokens (cross-attention to text)
  • Cross-attention lets each audio token attend to specific characters
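The shape of the problem can be sketched with a tiny randomly initialised T5 (hypothetical sizes for illustration only, not the published checkpoint): text ids go into the encoder, audio-token ids are the decoder targets, and the model produces one output distribution per audio token.

```python
import torch
from transformers import T5Config, T5ForConditionalGeneration

# Tiny illustrative config; real model: d_model=512, 8+8 layers, vocab 50,001
config = T5Config(vocab_size=128, d_model=32, d_ff=64, d_kv=8,
                  num_layers=2, num_heads=4, decoder_start_token_id=0)
model = T5ForConditionalGeneration(config)

text_ids = torch.randint(0, 128, (1, 10))   # encoder input: Tamil text tokens
audio_ids = torch.randint(0, 128, (1, 20))  # decoder target: FSQ audio tokens

out = model(input_ids=text_ids, labels=audio_ids)
print(out.logits.shape)  # torch.Size([1, 20, 128]): one distribution per audio token
```

Passing `labels=` makes the model shift the audio tokens right for teacher forcing and return a cross-entropy loss, so the same call works as a training step.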

What Changed

|                | Original        | Pruned                          |
|----------------|-----------------|---------------------------------|
| Vocab          | 250,112         | 50,001                          |
| shared.weight  | 250,112 x 512   | 50,001 x 512                    |
| lm_head        | separate        | tied to shared                  |
| Transformer    | ~44M            | ~44M (all 8+8 layers untouched) |
| Total params   | 556,291,456     | 69,662,592                      |
| Reduction      | --              | 87%                             |
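The pruning itself amounts to slicing the kept rows out of the embedding matrix and remapping token ids (a sketch of the technique, not the exact script used; `kept_ids` here is a placeholder, while the real kept ids come from the Tamil tokenizer):

```python
import torch
import torch.nn as nn

old_vocab, new_vocab, d_model = 250_112, 50_001, 512
shared = nn.Embedding(old_vocab, d_model)

# Assumption for illustration: the surviving ids; in reality these are the
# token ids the Tamil tokenizer actually uses.
kept_ids = torch.arange(new_vocab)

pruned = nn.Embedding(new_vocab, d_model)
with torch.no_grad():
    pruned.weight.copy_(shared.weight[kept_ids])

# With the head tied to shared, the embedding cost is a single matrix:
print(new_vocab * d_model)  # 25600512 embedding params + ~44M transformer
```

50,001 x 512 = 25,600,512 embedding parameters plus the ~44M untouched transformer weights gives the 69,662,592 total in the table.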

Training Notes

IMPORTANT: When adding CV3 tokens, resize in float32, initialise the new rows, and only then cast to bf16:

```python
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

tokenizer = AutoTokenizer.from_pretrained("sharvesh007/mt5-small-tamil-tts")
model = AutoModelForSeq2SeqLM.from_pretrained(
    "sharvesh007/mt5-small-tamil-tts", torch_dtype=torch.float32
)

tokenizer.add_tokens([f"<cv3_{i}>" for i in range(6561)])
model.resize_token_embeddings(len(tokenizer))

# Randomly initialise the new rows (breaks mean-init symmetry)
with torch.no_grad():
    cv3_start = tokenizer.convert_tokens_to_ids("<cv3_0>")
    d = model.config.d_model
    std = (2.0 / (d + 1)) ** 0.5
    model.shared.weight[cv3_start:].normal_(0, std)

# NOW cast to bf16
model = model.to(torch.bfloat16)
```
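Given the v2 fix above, it is worth verifying that the resize-and-cast recipe does not silently untie the head again. A sketch with a tiny random T5 config (illustrative sizes, not the published checkpoint) shows that with tie_word_embeddings=True, tying survives both the resize and the dtype cast:

```python
import torch
from transformers import T5Config, T5ForConditionalGeneration

# Tiny illustrative config standing in for the real checkpoint
config = T5Config(vocab_size=100, d_model=32, d_ff=64, d_kv=8,
                  num_layers=1, num_heads=4, tie_word_embeddings=True)
model = T5ForConditionalGeneration(config)

model.resize_token_embeddings(100 + 16)  # e.g. adding 16 new tokens
model = model.to(torch.bfloat16)

# Head and embedding must still be one tensor, now in bf16
assert model.lm_head.weight.data_ptr() == model.shared.weight.data_ptr()
assert model.shared.weight.dtype == torch.bfloat16
```

The same two assertions can be run on the real model after the CV3 resize as a cheap sanity check before training.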

Usage

```python
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

tokenizer = AutoTokenizer.from_pretrained("sharvesh007/mt5-small-tamil-tts")
model = AutoModelForSeq2SeqLM.from_pretrained("sharvesh007/mt5-small-tamil-tts")

ids = tokenizer.encode("வணக்கம்", add_special_tokens=False)
print(f"IDs: {ids}, max: {max(ids)}")  # max should be < 50,001
```