---
base_model:
- meta-llama/Llama-3.2-3B-Instruct
datasets:
- JunxiongWang/sftdatasetv3
- HuggingFaceH4/ultrafeedback_binarized
- HuggingFaceH4/orca_dpo_pairs
- JunxiongWang/llama3-ultrafeedback-armorm
model-index:
- name: X-EcoMLA-3B3B-fixed-kv816-DPO
  results: []
tags:
- alignment-handbook
- generated_from_trainer
license: apache-2.0
---

# X-EcoMLA: Upcycling Pre-Trained Attention into MLA for Efficient and Extreme KV Compression

X-EcoMLA is an efficient KV cache compression technique for large language models (LLMs) proposed by AMD that upcycles transformer blocks into Multi-head Latent Attention (MLA) for extreme KV cache compression and computational efficiency.
Instead of training an MLA model from scratch, X-EcoMLA first initializes the MLA weights from a Singular Value Decomposition (SVD) of the existing transformer weights and then applies lightweight pre-training or post-training distillation.
This model, `X-EcoMLA-3B3B-fixed-kv816-DPO`, was created by adapting the pre-trained `Llama-3.2-3B-Instruct` model through post-training on AMD Instinct™ MI300X GPUs, which avoids the cost of pre-training from scratch.
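
The SVD initialization described above can be illustrated with a minimal NumPy sketch. It assumes the key and value projections are stored as `(d_model, n_kv_heads * head_dim)` matrices applied as `x @ W`, factors their concatenation with a truncated SVD, and returns a down-projection that produces the cached latent plus up-projections that recover keys and values. The helper name `svd_init_kv` is hypothetical, and the sketch ignores the query path and the decoupled RoPE dimensions (d<sub>rope</sub>), so it is not the authors' exact procedure.

```python
import numpy as np


def svd_init_kv(w_k: np.ndarray, w_v: np.ndarray, r_kv: int):
    """Hypothetical helper: rank-r_kv SVD factorization of the joint K/V projection.

    w_k, w_v: (d_model, n_kv_heads * head_dim) matrices, applied as x @ W.
    Returns (w_dkv, w_uk, w_uv):
      w_dkv: (d_model, r_kv)  down-projection whose output is the cached latent
      w_uk:  (r_kv, d_kv)     up-projection that reconstructs keys
      w_uv:  (r_kv, d_kv)     up-projection that reconstructs values
    """
    d_kv = w_k.shape[1]
    w_kv = np.concatenate([w_k, w_v], axis=1)            # (d_model, 2 * d_kv)
    u, s, vt = np.linalg.svd(w_kv, full_matrices=False)  # w_kv = u @ diag(s) @ vt
    w_dkv = u[:, :r_kv] * s[:r_kv]                       # keep top r_kv directions, fold in singular values
    w_ukv = vt[:r_kv, :]                                 # (r_kv, 2 * d_kv)
    return w_dkv, w_ukv[:, :d_kv], w_ukv[:, d_kv:]


# Usage with random stand-in weights (not real Llama weights).
rng = np.random.default_rng(0)
d_model, d_kv, r_kv = 256, 128, 64
w_k = rng.standard_normal((d_model, d_kv))
w_v = rng.standard_normal((d_model, d_kv))
w_dkv, w_uk, w_uv = svd_init_kv(w_k, w_v, r_kv)

x = rng.standard_normal((4, d_model))  # 4 tokens
latent = x @ w_dkv                      # (4, r_kv): the only per-token tensor that is cached
k_approx, v_approx = latent @ w_uk, latent @ w_uv
print(latent.shape, k_approx.shape, v_approx.shape)
```

Folding the singular values into the down-projection is an arbitrary choice; splitting them as square roots between the two factors yields the same product. By the Eckart-Young theorem, the truncated SVD is the best rank-r<sub>kv</sub> approximation of the joint projection in Frobenius norm, which is why it is a reasonable starting point before distillation.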


## Key Takeaways
- Announcing X-EcoMLA, an efficient approach to upcycle existing transformer blocks into MLA.
- Extreme KV Cache Compression: X-EcoMLA reduces the KV cache size by 6.4x to 10.6x using only 3.6B to 7B training tokens, while preserving almost 100% of the base model's average zero-shot performance on LM Harness tasks.
- Novel SVD Initialization: X-EcoMLA employs an efficient SVD-based weight initialization which dramatically improves the training efficiency and model performance.

## Model Composition Pipeline

The X-EcoMLA models are not trained from scratch. Instead, they are composed from powerful pre-trained Transformers through a lightweight and efficient pipeline. The creation of this model followed these stages:

| Stage             | Action                                | Description                                                                                                                                                                                  |   
|-------------------|---------------------------------------|----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
| 1. Base Model     | Llama-3.2-3B-Instruct                 | The starting point is a high-quality, pre-trained Transformer model.                                                                                                                         | 
| 2. Initialization | Structured Weight Mapping             | MLA models are initialized from the base model's weights using SVD.                                     | 
| 3. SFT            | End-to-End Knowledge Distillation     | The initialized model is fine-tuned via knowledge distillation.                                                  |  
| 4. Alignment      | Direct Preference Optimization (DPO)  | In the final stage, DPO is used to align the model's preferences, with the distilled student model itself serving as the reference model for stability.                                      | 
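
Stages 3 and 4 in the table above optimize two standard objectives. As a rough NumPy sketch (an assumption about the general form of these losses, not the training code used for this model), stage 3 can be expressed as a token-level KL divergence from teacher to student, and stage 4 as the DPO loss computed from summed sequence log-probabilities, with the reference log-probabilities coming from the frozen distilled student. The helper names and the default `beta=0.1` are hypothetical.

```python
import numpy as np


def log_softmax(logits: np.ndarray, temperature: float = 1.0) -> np.ndarray:
    """Numerically stable log-softmax over the last (vocabulary) axis."""
    z = logits / temperature
    z = z - z.max(axis=-1, keepdims=True)
    return z - np.log(np.exp(z).sum(axis=-1, keepdims=True))


def kd_loss(student_logits, teacher_logits, temperature: float = 1.0) -> float:
    """Stage 3 sketch: mean per-token KL(teacher || student)."""
    log_p_t = log_softmax(teacher_logits, temperature)
    log_p_s = log_softmax(student_logits, temperature)
    kl = (np.exp(log_p_t) * (log_p_t - log_p_s)).sum(axis=-1)
    return float(kl.mean())


def dpo_loss(pi_chosen, pi_rejected, ref_chosen, ref_rejected, beta: float = 0.1) -> float:
    """Stage 4 sketch: DPO loss from summed sequence log-probs, shape (batch,).

    pi_* come from the model being aligned; ref_* come from the frozen reference
    (assumed here to be the distilled SFT student, as in the pipeline table).
    """
    margin = beta * ((pi_chosen - ref_chosen) - (pi_rejected - ref_rejected))
    # -log(sigmoid(m)) == log(1 + exp(-m)); logaddexp keeps it stable for large |m|
    return float(np.logaddexp(0.0, -margin).mean())


rng = np.random.default_rng(0)
teacher = rng.standard_normal((2, 5, 32))                  # (batch, seq, vocab)
student = teacher + 0.1 * rng.standard_normal((2, 5, 32))
print("KD loss:", kd_loss(student, teacher, temperature=2.0))

ref_c, ref_r = np.array([-12.0, -20.0]), np.array([-15.0, -18.0])
print("DPO loss:", dpo_loss(ref_c + 1.0, ref_r - 1.0, ref_c, ref_r))
```

Many distillation implementations also multiply the KL term by `temperature**2` to keep gradient magnitudes comparable across temperatures; that scaling is omitted here.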

## Training Data 

|Stage      | Dataset                                                                   | License                |   
|-----------|---------------------------------------------------------------------------|------------------------|
| SFT       | https://huggingface.co/datasets/teknium/OpenHermes-2.5                    | Refer to source materials | 
| SFT       | https://huggingface.co/datasets/tomg-group-umd/GenQA                      | CC BY-NC 4.0           | 
| SFT       | https://huggingface.co/datasets/BAAI/Infinity-Instruct                    | CC BY-SA 4.0           | 
| DPO       | https://huggingface.co/datasets/HuggingFaceH4/ultrafeedback_binarized     | MIT                    | 
| DPO       | https://huggingface.co/datasets/HuggingFaceH4/orca_dpo_pairs              | MIT                    |  
| DPO       | https://huggingface.co/datasets/JunxiongWang/llama3-ultrafeedback-armorm  | MIT                    | 

## Getting Started

### Installation

```bash
git clone https://github.com/AMD-AIG-AIMA/AMD-Hybrid-Models.git
cd AMD-Hybrid-Models/X-EcoMLA
```
Then follow the installation instructions in the `AMD-AIG-AIMA/AMD-Hybrid-Models` repository.

### Example Usage
Once the installation is complete, you can run the following code for a quick test:
```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from mla.hybrid_wrapper import MLATransformerHybridModelWrapper

checkpoint = "amd/X-EcoMLA-3B3B-fixed-kv816-DPO"

model = MLATransformerHybridModelWrapper.from_pretrained(checkpoint, torch_dtype=torch.bfloat16).cuda()
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model.eval()

# Format the prompt using the chat template
prompt = [{"role": "user", "content": "What are the benefits of hybrid language models?"}]
input_ids = tokenizer.apply_chat_template(
    prompt,
    add_generation_prompt=True,
    return_tensors='pt'
).cuda()

# Generate a response
tokens = model.generate(
    input_ids, 
    max_new_tokens=256,
    temperature=0.7,
    do_sample=True,
    eos_token_id=tokenizer.eos_token_id
)
print(tokenizer.decode(tokens[0], skip_special_tokens=False))
```

### Model Evaluation

To run the zero-shot LM Harness evaluation on the tasks reported below:
```bash
python benchmark/llm_eval/lm_harness_eval.py \
    --model mla_hybrid \
    --model_args pretrained="amd/X-EcoMLA-3B3B-fixed-kv816-DPO" \
    --tasks mmlu,hellaswag,piqa,arc_easy,arc_challenge,winogrande,openbookqa,pubmedqa,race \
    --num_fewshot 0 --device cuda --batch_size 16
```

### Model details

| Model | KV Size | Target Model| Teacher Model| Training Tokens| Pre-/Post-Training| r<sub>kv</sub>| r<sub>q</sub>| d<sub>rope</sub> | d<sub>nope</sub> |
|-------|--------:|------:|------:|------:|------:|-----:|-----:|---------:|---------:|  
|X-EcoMLA-1B1B-fixed-kv512-DPO |  53.1% |  Llama-3.2-1B-Instruct|  Llama-3.2-1B-Instruct | 7B | Post |512 | 864 | 32 | 32 | 
|X-EcoMLA-1B1B-dynamic-0.95-DPO |  54.7% |  Llama-3.2-1B-Instruct|  Llama-3.2-1B-Instruct | 7B |Post | 0.95 | 0.95 | 32 | 32 | 
|X-EcoMLA-1B8B-fixed-kv64-DPO |  9.4% |  Llama-3.2-1B-Instruct|  Llama-3.1-8B-Instruct | 7B |  Post | 64 |1424 | 32 | 32 | 
|X-EcoMLA-3B3B-fixed-kv816-DPO |  43% |  Llama-3.2-3B-Instruct|  Llama-3.2-3B-Instruct | 7B |  Post | 816 | 1536 | 64 | 64 | 
|X-EcoMLA-3B3B-dynamic-0.95-DPO |  43% |  Llama-3.2-3B-Instruct|  Llama-3.2-3B-Instruct | 7B |  Post |0.95 |0.95 | 64 | 64 | 
|X-EcoMLA-SmolLM-1.7B-fixed-kv480-Pretrain |  12.5% |  SmolLM-1.7B |  -  | 6B |  Pre | 480 | 2048 | 32 | 32 | 
|X-EcoMLA-SmolLM-1.7B1.7B-fixed-kv480-Pretrain |  12.5% |  SmolLM-1.7B |  SmolLM-1.7B | 6B |  Pre | 480 | 2048 | 32 | 32 | 
|X-EcoMLA-SmolLM-1.7B1.7B-fixed-kv480-DPO |  12.5% |  SmolLM-1.7B-Instruct |  SmolLM-1.7B-Instruct | 7B |  Post | 480 | 2048 | 32 | 32 |  
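
The KV Size column of the fixed-rank models can be reproduced with simple arithmetic, under two assumptions: that MLA caches one r<sub>kv</sub>-dimensional latent plus one shared d<sub>rope</sub>-dimensional RoPE key per token and layer, and that the base models use the attention shapes in their public configurations (8 KV heads of dimension 64 for Llama-3.2-1B, 8 KV heads of dimension 128 for Llama-3.2-3B, 32 KV heads of dimension 64 for SmolLM-1.7B). Because the layer count is the same on both sides, it cancels out of the ratio. The helper names are hypothetical, and the dynamic-rank variants are omitted because the table does not list fixed ranks for them.

```python
def gqa_cache_elems(n_kv_heads: int, head_dim: int) -> int:
    """Standard attention caches one key and one value vector per KV head."""
    return 2 * n_kv_heads * head_dim


def mla_cache_elems(r_kv: int, d_rope: int) -> int:
    """Assumed MLA cache: the r_kv-dim latent plus a single d_rope-dim RoPE key."""
    return r_kv + d_rope


# name: ((n_kv_heads, head_dim) of the base model, assumed from its public config,
#        (r_kv, d_rope) from the table above)
MODELS = {
    "X-EcoMLA-1B1B-fixed-kv512-DPO": ((8, 64), (512, 32)),
    "X-EcoMLA-1B8B-fixed-kv64-DPO": ((8, 64), (64, 32)),
    "X-EcoMLA-3B3B-fixed-kv816-DPO": ((8, 128), (816, 64)),
    "X-EcoMLA-SmolLM-1.7B1.7B-fixed-kv480-DPO": ((32, 64), (480, 32)),
}


def kv_size_ratio(name: str) -> float:
    """Per-token cache size of the MLA model relative to its base model."""
    (n_kv_heads, head_dim), (r_kv, d_rope) = MODELS[name]
    return mla_cache_elems(r_kv, d_rope) / gqa_cache_elems(n_kv_heads, head_dim)


for name in MODELS:
    ratio = kv_size_ratio(name)
    print(f"{name}: {ratio:.1%} of the base KV cache ({1 / ratio:.2f}x smaller)")
```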

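### Benchmark results
X-EcoMLA was evaluated on zero-shot tasks from the Language Model Harness and compared against its base model and other post-training methods (see the paper); the table below reports the comparison with the base model, Llama-3.2-3B-Instruct. Both X-EcoMLA variants outperform the base model on most tasks while caching only 43% of its KV size. The largest drops are on MMLU (about 3 points for the fixed-rank model) and PubMedQA (about 2.8 points for the fixed-rank model).
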
| Tasks             | Metric   |  Llama-3.2-3B-Instruct | X-EcoMLA-3B3B-fixed-kv816-DPO | X-EcoMLA-3B3B-dynamic-0.95-DPO |  
|-------------------|----------|----------------: |----------------: |----------------:| 
| arc_challenge     | acc      |    0.4369±0.0145 |    0.4753±0.0146 |   0.4710±0.0146 |                
|                   | acc_norm |    0.4590±0.0146 |    0.4821±0.0146 |   0.4846±0.0146 |                 
| arc_easy          | acc      |    0.7428±0.0090 |    0.7660±0.0087 |   0.7580±0.0088 |               
|                   | acc_norm |    0.6776±0.0096 |    0.7045±0.0094 |   0.6999±0.0094 |                 
| hellaswag         | acc      |    0.5222±0.0050 |    0.5288±0.0050 |   0.5320±0.0050 |                  
|                   | acc_norm |    0.7036±0.0046 |    0.7224±0.0045 |   0.7226±0.0045 |                 
| mmlu              | acc      |    0.6046±0.1057 |    0.5742±0.1014 |   0.5773±0.1028 |                  
| - humanities      | acc      |    0.5926±0.0826 |    0.5507±0.0843 |   0.5518±0.0851 |                 
| - other           | acc      |    0.6598±0.1118 |    0.6312±0.1011 |   0.6344±0.1070 |                 
| - social_sciences | acc      |    0.6701±0.0712 |    0.6383±0.0741 |   0.6422±0.0765 |                 
| - stem            | acc      |    0.5043±0.1122 |    0.4906±0.1089 |   0.4960±0.1071 |                 
| openbookqa        | acc      |    0.2740±0.0200 |    0.2920±0.0204 |   0.3000±0.0205 |                  
|                   | acc_norm |    0.3620±0.0215 |    0.3840±0.0218 |   0.3940±0.0219 |                 
| piqa              | acc      |    0.7606±0.0100 |    0.7573±0.0100 |   0.7579±0.0100 |                 
|                   | acc_norm |    0.7557±0.0100 |    0.7655±0.0099 |   0.7579±0.0100 |                   
| pubmedqa          | acc      |    0.6960±0.0206 |    0.6680±0.0211 |   0.6840±0.0208 |                  
| race              | acc      |    0.4077±0.0152 |    0.4622±0.0154 |   0.4632±0.0154 |                 
| winogrande        | acc      |    0.6717±0.0132 |    0.6859±0.0130 |   0.6590±0.0133 |                 
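
To summarize the table in one number per model, the following sketch computes an unweighted mean of the `acc` metric over the nine top-level tasks, excluding the MMLU subcategories and the `acc_norm` rows. This aggregation is an illustrative choice and may differ from the averaging used in the paper.

```python
from statistics import mean

TASKS = ["arc_challenge", "arc_easy", "hellaswag", "mmlu", "openbookqa",
         "piqa", "pubmedqa", "race", "winogrande"]

# `acc` values copied from the table above, in TASKS order.
SCORES = {
    "Llama-3.2-3B-Instruct":          [0.4369, 0.7428, 0.5222, 0.6046, 0.2740, 0.7606, 0.6960, 0.4077, 0.6717],
    "X-EcoMLA-3B3B-fixed-kv816-DPO":  [0.4753, 0.7660, 0.5288, 0.5742, 0.2920, 0.7573, 0.6680, 0.4622, 0.6859],
    "X-EcoMLA-3B3B-dynamic-0.95-DPO": [0.4710, 0.7580, 0.5320, 0.5773, 0.3000, 0.7579, 0.6840, 0.4632, 0.6590],
}

base_name = "Llama-3.2-3B-Instruct"
base_avg = mean(SCORES[base_name])

for name, accs in SCORES.items():
    avg = mean(accs)
    print(f"{name}: mean acc {avg:.4f} ({avg / base_avg:.1%} of base)")
    if name != base_name:
        # Per-task change versus the base model, in accuracy points (x100).
        deltas = {t: round(100 * (a - b), 2) for t, a, b in zip(TASKS, accs, SCORES[base_name])}
        print("  largest drop:", min(deltas.items(), key=lambda kv: kv[1]))
```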


## Conclusion
X-EcoMLA demonstrates an efficient technique to upcycle pre-trained Transformers into MLA modules to compress KV cache. This work highlights the viability of post-training hybridization as a cost-effective and environmentally sustainable alternative to full retraining, paving the way for the deployment of powerful LLMs in resource-constrained environments.

## Bias, Risks, and Limitations
- This model is a research artifact and has not been evaluated for safety in production use cases.
- The model's performance is dependent on the quality of its pre-trained base model and the teacher model used during distillation. Its capabilities and biases are inherited from these sources.
- The model may generate content that is factually inaccurate, biased, or otherwise objectionable. Users should be aware of these risks and implement appropriate safeguards for their applications.
- One limitation of this work is the reliance on a strong teacher model for knowledge transfer, which may not always be available. Distillation from a teacher also adds to the resource requirements during the post-training phase.

## Citation
If you find this model useful, please consider citing the original paper:
```bibtex
@article{li2025x,
  title={X-ecomla: Upcycling pre-trained attention into mla for efficient and extreme kv compression},
  author={Li, Guihong and Rezagholizadeh, Mehdi and Yang, Mingyu and Appia, Vikram and Barsoum, Emad},
  journal={arXiv preprint arXiv:2503.11132},
  year={2025}
}
```