# mT5-Small — Tamil TTS (Vocab-Pruned, Tying-Fixed)
Vocab-pruned google/mt5-small for Tamil TTS training.
## Critical Fix (v2)
Weight tying is now correct: only `shared.weight` is stored in the checkpoint.
`lm_head.weight`, `encoder.embed_tokens.weight`, and `decoder.embed_tokens.weight`
are all tied to `shared.weight` at load time via `tie_word_embeddings=True`.

The previous version saved both `shared.weight` and `lm_head.weight`, which caused
HuggingFace to silently UNTIE them; the two matrices then diverged during training and the logits exploded.
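In a correctly tied model, the head and the embedding are literally the same tensor, not two equal copies. A minimal PyTorch sketch (toy sizes, not the actual model) of what tying means and how to check it:

```python
import torch

# Minimal illustration of weight tying: the LM head reuses the embedding
# matrix's storage, so there is exactly one trainable copy.
shared = torch.nn.Embedding(8, 4)
lm_head = torch.nn.Linear(4, 8, bias=False)
lm_head.weight = shared.weight  # tie: both modules now hold the same Parameter

print(lm_head.weight.data_ptr() == shared.weight.data_ptr())  # True when tied

# An in-place update to one is visible through the other.
with torch.no_grad():
    shared.weight.zero_()
print(bool((lm_head.weight == 0).all()))  # True
```

The same `data_ptr()` comparison on `model.shared.weight` and `model.lm_head.weight` can be used to sanity-check tying after loading the checkpoint.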
## Why mT5 for TTS?
Text -> audio tokens is a true seq2seq problem:
- Encoder reads Tamil text (bidirectional, deep understanding)
- Decoder generates FSQ audio tokens (cross-attention to text)
- Cross-attention lets each audio token attend to specific characters
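The flow above can be sketched with a toy, randomly initialized T5 (the same architecture family as mT5); this is only an illustration of the seq2seq shape, not the real checkpoint:

```python
import torch
from transformers import T5Config, T5ForConditionalGeneration

# Toy config, randomly initialized, just to show the text -> audio-token flow.
# The real model is sharvesh007/mt5-small-tamil-tts.
cfg = T5Config(
    vocab_size=100, d_model=32, d_ff=64, num_layers=2,
    num_decoder_layers=2, num_heads=4, decoder_start_token_id=0,
)
model = T5ForConditionalGeneration(cfg)

text_ids = torch.tensor([[5, 17, 42]])  # stand-in for tokenized Tamil text
# The decoder emits tokens autoregressively, cross-attending to the encoded
# text at every step; in the TTS setup these would be FSQ audio tokens.
audio_ids = model.generate(text_ids, max_new_tokens=8)
print(audio_ids.shape)  # (1, <=9): decoder start token + up to 8 generated
```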
## What Changed
| | Original | Pruned |
|---|---|---|
| Vocab | 250,112 | 50,001 |
| `shared.weight` | 250,112 × 512 | 50,001 × 512 |
| `lm_head` | separate | tied to `shared` |
| Transformer | ~44M | ~44M (all 8+8 layers untouched) |
| Total params | 556,291,456 | 69,662,592 |
| Reduction | -- | 87% |
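The pruned total is consistent with a single tied embedding matrix; a quick back-of-the-envelope check using the numbers from the table:

```python
# Parameter accounting for the pruned model (numbers from the table above).
vocab, d_model = 50_001, 512
embed_params = vocab * d_model  # shared.weight; lm_head reuses it via tying
total = 69_662_592              # pruned total from the table
transformer_params = total - embed_params

print(embed_params)        # 25600512
print(transformer_params)  # 44062080, i.e. the ~44M untouched transformer
```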
## Training Notes
IMPORTANT: When adding CV3 tokens, resize in float32, then cast:

```python
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

tokenizer = AutoTokenizer.from_pretrained("sharvesh007/mt5-small-tamil-tts")
model = AutoModelForSeq2SeqLM.from_pretrained(
    "sharvesh007/mt5-small-tamil-tts", torch_dtype=torch.float32
)

tokenizer.add_tokens([f"<cv3_{i}>" for i in range(6561)])
model.resize_token_embeddings(len(tokenizer))

# Random init for the new rows (break mean-init symmetry)
with torch.no_grad():
    cv3_start = tokenizer.convert_tokens_to_ids("<cv3_0>")
    d = model.config.d_model
    std = (2.0 / (d + 1)) ** 0.5
    model.shared.weight[cv3_start:].normal_(0, std)

# NOW cast to bf16
model = model.to(torch.bfloat16)
```
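For reference, with mT5-small's `d_model = 512` the init scale used for the new rows works out to:

```python
# std = sqrt(2 / (d_model + 1)) for d_model = 512
d = 512
std = (2.0 / (d + 1)) ** 0.5
print(round(std, 4))  # 0.0624
```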
## Usage
```python
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

tokenizer = AutoTokenizer.from_pretrained("sharvesh007/mt5-small-tamil-tts")
model = AutoModelForSeq2SeqLM.from_pretrained("sharvesh007/mt5-small-tamil-tts")

ids = tokenizer.encode("வணக்கம்", add_special_tokens=False)
print(f"IDs: {ids}, max: {max(ids)}")
```
## Model Tree

Base model: `google/mt5-small`