ffurfaro commited on
Commit
2f83c80
·
verified ·
1 Parent(s): f1768fc

Upload model + init tptt code

Browse files
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ lora_delta_product_m0.5_constant/tokenizer.json filter=lfs diff=lfs merge=lfs -text
README.md ADDED
@@ -0,0 +1,77 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ language: en
3
+ license: apache-2.0
4
+ library_name: transformers
5
+ tags:
6
+ - tptt
7
+ - peft
8
+ - trust_remote_code
9
+ pipeline_tag: text-generation
10
+ base_model: google/gemma-3-270m
11
+ datasets:
12
+ - yahma/alpaca-cleaned
13
+ ---
14
+
15
+ # Titanesque-gemma-3-270m
16
+
17
+ <p align="center">
18
+ <a href="https://arxiv.org/abs/2506.17671">
19
+ <img alt="arXiv" src="https://img.shields.io/badge/arXiv-tptt-blueviolet.svg">
20
+ </a>
21
+ <a href="https://pypi.org/project/tptt/">
22
+ <img alt="PyPI" src="https://img.shields.io/pypi/v/tptt?color=orange">
23
+ </a>
24
+ <a href="https://github.com/fabienfrfr/tptt/">
25
+ <img alt="Release" src="https://img.shields.io/github/v/release/fabienfrfr/tptt?color=brightgreen">
26
+ </a>
27
+ <a href="https://fabienfrfr.github.io/tptt/">
28
+ <img alt="Documentation" src="https://img.shields.io/badge/docs-online-blue">
29
+ </a>
30
+ <a href="https://huggingface.co/ffurfaro">
31
+ <img alt="HuggingFace" src="https://img.shields.io/badge/hf-ffurfaro-yellow">
32
+ </a>
33
+ </p>
34
+
35
+ Titanesque version of `google/gemma-3-270m` with parallel linearized attention (TPTT 😊) and PEFT.
36
+
37
+ The architecture was presented in the paper [TPTT](https://huggingface.co/papers/2506.17671).
38
+
39
+
40
+ ## Model list
41
+
42
+ Classic model parameter with LiZA injection :
43
+
44
+ | Subfolder | Max Self Attn Length | Mag Weight | Cross Gate | Max Chunk Size | Bidirectional | LoRA | Description |
45
+ |-------------------------------|----------------------|------------|------------|----------------|---------------|------|-------------------------------------------------------|
46
+ | delta_rule | 8192 (default) | 0.5 | False | 64 | False | Yes | Parallel linearized attention with delta_rule operator|
47
+ | delta_rule_gelu | 8192 (default) | 0.5 | False | 64 | False | Yes | Non-linear operator with gelu activation |
48
+ | delta_product | 8192 (default) | 0.5 | False | 64 | False | Yes | Second order operator with derivative trick |
49
+ | delta_product_r | 8192 (default) | 0.5 | False | 64 | False | Yes | Second order operator with rotative trick |
50
+ | delta_product_c | 8192 (default) | 0.5 | False | 64 | False | Yes | Second order operator with combined trick |
51
+
52
+ ## Usage
53
+
54
+ ```python
55
+ from transformers import AutoModelForCausalLM, AutoTokenizer
56
+
57
+ model = AutoModelForCausalLM.from_pretrained(
58
+ "ffurfaro/Titanesque-gemma-3-270m",
59
+ subfolder="tptt_subfolder", # see in repo tree
60
+ trust_remote_code=True
61
+ )
62
+ tokenizer = AutoTokenizer.from_pretrained("google/gemma-3-270m")
63
+
64
+ prompt = "Your prompt here"
65
+ inputs = tokenizer(prompt, return_tensors="pt")
66
+ outputs = model.generate(**inputs, max_new_tokens=100)
67
+ print(tokenizer.decode(outputs[0], skip_special_tokens=True))
68
+
69
+ ```
70
+
71
+
72
+ ## Citation & Contact
73
+
74
+ If you use TPTT in your academic work, please cite [Furfaro](https://huggingface.co/ffurfaro). For questions or support, please open an issue on the [GitHub repository](https://github.com/fabienfrfr/tptt) or contact the maintainer.
75
+
76
+
77
+ ---
__init__.py ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ This module implements the TPTT model with linear attention (LiZA) and LoRA support.
3
+ """
4
+
5
+ from .configuration_tptt import (TpttConfig, generate_model_card,
6
+ parse_mode_name)
7
+ from .modeling_tptt import (LCache, LinearAttention, LinearAttentionOp,
8
+ LiZAttention, TpttModel, get_tptt_model,
9
+ load_tptt_safetensors, save_tptt_safetensors)
10
+ from .train_tptt import LiZACallback, SaveBestModelCallback
11
+
12
+ __all__ = [
13
+ "TpttConfig",
14
+ "TpttModel",
15
+ "get_tptt_model",
16
+ "LiZACallback",
17
+ "SaveBestModelCallback",
18
+ "LCache",
19
+ "LinearAttentionOp",
20
+ "LiZAttention",
21
+ "generate_model_card",
22
+ "LinearAttention",
23
+ "parse_mode_name",
24
+ "load_tptt_safetensors",
25
+ "save_tptt_safetensors",
26
+ ]
configuration_tptt.py ADDED
@@ -0,0 +1,318 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # pylint: disable=too-many-arguments, too-many-positional-arguments, too-many-instance-attributes, too-many-locals
2
+ """
3
+ Author : Fabien FURFARO
4
+ """
5
+
6
+ import logging
7
+ import os
8
+ import re
9
+ from typing import Any, Dict, List, Optional, Union
10
+ from jinja2 import Environment, FileSystemLoader
11
+
12
+ import psutil
13
+ import torch
14
+ from transformers import AutoConfig, PretrainedConfig
15
+
16
+ logger = logging.getLogger(__name__) # monitoring
17
+
18
+ # Constants
19
+ BYTES_IN_GB = 1024**3
20
+
21
+
22
def convert_sets_to_lists(obj):
    """Make a LoRA config JSON-serializable by replacing sets with lists.

    Dicts, lists and tuples are walked recursively; sets are converted
    shallowly (their members are hashable, so no nested dict/list inside).
    Any other value is returned unchanged.
    """
    if isinstance(obj, dict):
        return {key: convert_sets_to_lists(value) for key, value in obj.items()}
    if isinstance(obj, set):
        return list(obj)
    if isinstance(obj, (list, tuple)):
        return [convert_sets_to_lists(item) for item in obj]
    return obj
31
+
32
+
33
class TpttConfig(PretrainedConfig):
    """
    Configuration class for the TPTT model.

    Merges the backbone config (e.g., Llama/Gemma) with the custom TPTT
    parameters: linearized-attention operator mode, memory/attention gating
    weight (MaG), chunking, LoRA settings, etc.
    """

    model_type = "tptt"
    # Lets AutoConfig / AutoModelForCausalLM resolve the remote-code classes.
    auto_map = {
        "AutoModelForCausalLM": "modeling_tptt.TpttModel",
        "AutoConfig": "configuration_tptt.TpttConfig",
    }
    architectures = ["TpttModel"]

    RECURRENT_MODES = {
        "delta_rule": {
            "order": 1,
            "gate_type": "k",
            "linear": True,
            "trick": "derivative",
        },
        "delta_rule_v": {
            "order": 1,
            "gate_type": "v",
            "linear": True,
            "trick": "derivative",
        },
        "delta_rule_kv": {
            "order": 1,
            "gate_type": "kv",
            "linear": True,
            "trick": "derivative",
        },
        "delta_rule_gelu": {
            "order": 1,
            "gate_type": "k",
            "linear": False,
            "trick": "derivative",
        },
        "delta_product": {
            "order": 2,
            "gate_type": "k",
            "linear": True,
            "trick": "derivative",
        },
        "delta_product_r": {
            "order": 2,
            "gate_type": "k",
            "linear": True,
            "trick": "rotative",
        },
        "delta_product_c": {
            "order": 2,
            "gate_type": "k",
            "linear": True,
            "trick": "combined",
        },
    }  # Tested modes, see parse_mode_name if you want to add more

    def __init__(
        self,
        base_model_config: Optional[Union[dict, PretrainedConfig]] = None,
        base_model_name: str = "meta-llama/Llama-3.2-1B",
        base_model_subfolder: Optional[str] = None,
        name_or_path: Optional[str] = None,
        model_task: str = "causal_lm",
        target_modules_names: Optional[List[str]] = None,
        operator_mode: str = "delta_rule",
        use_linear_checkpoint: Optional[bool] = None,
        max_self_attn_length: Optional[
            int
        ] = None,  # unnecessary if SWA, else, standards 8192
        base_scale_attn: bool = False,
        mag_weight: float = 0.5,  # if 1.0, use only linear operator
        cross_gate: bool = False,  # unlinear mixing strategy
        max_chunk_size: int = 64,  # 128 if adaptive chunking (longest)
        linear_precision: Union[str, torch.dtype] = "float32",
        lora_config: Optional[dict] = None,  # only serialized accepted
        padding_side: Optional[str] = None,  # for tokenizer, default "right"
        bidirectional: bool = False,  # if True, use bidirectional attention
        pooling_config: Optional[Dict[str, Any]] = None,
        **kwargs,
    ):
        """Build the merged backbone + TPTT configuration.

        If ``base_model_config`` is given (dict or PretrainedConfig) it is
        merged field-by-field into this config; otherwise the backbone config
        is fetched from the Hub/local path named by ``base_model_name``.
        """
        if base_model_config is not None:
            if isinstance(base_model_config, PretrainedConfig):
                base_model_config = base_model_config.to_dict()
        else:
            # Load config from Hugging Face Hub or a local path
            base_model_config = AutoConfig.from_pretrained(
                base_model_name, **kwargs
            ).to_dict()
        # Merge all backbone fields into this config
        for k, v in base_model_config.items():
            setattr(self, k, v)

        self.base_model_name = base_model_name
        self.base_model_subfolder = base_model_subfolder
        self.model_task = model_task

        if name_or_path is not None:
            self._name_or_path = name_or_path
        else:
            # Derive a display name from the backbone repo id (drop the org).
            if "/" in base_model_name:
                self._name_or_path = "Titans-" + base_model_name.split("/", 1)[1]
            else:
                self._name_or_path = "Titans-" + base_model_name

        self.target_modules_names = target_modules_names or [
            "attn",
            "self_attn",
            "attention",
        ]
        self.operator_mode = operator_mode

        # Detect available memory on accelerator device
        if torch.cuda.is_available():
            _, total_mem = torch.cuda.mem_get_info()
        else:
            total_mem = psutil.virtual_memory().total
        total_mem_gb = total_mem / BYTES_IN_GB

        # Default to checkpointing the linear pass on small-memory hosts.
        self.use_linear_checkpoint = (
            total_mem_gb < 16
            if use_linear_checkpoint is None
            else use_linear_checkpoint
        )

        self.base_scale_attn = base_scale_attn
        self.mag_weight = mag_weight
        self.cross_gate = cross_gate
        self.max_chunk_size = max_chunk_size
        self.max_self_attn_length = max_self_attn_length
        if isinstance(linear_precision, torch.dtype):
            # Store the dtype as a plain string (e.g. "float32") so it serializes.
            linear_precision = str(linear_precision).replace("torch.", "")
        self.linear_precision = linear_precision

        self.lora_config = lora_config
        if lora_config is not None:
            # PEFT enums are not JSON-serializable; keep only their value.
            if hasattr(self.lora_config.get("peft_type"), "value"):
                self.lora_config["peft_type"] = self.lora_config["peft_type"].value
            self.lora_config = convert_sets_to_lists(self.lora_config)

        self.padding_side = padding_side
        self.bidirectional = bidirectional
        if self.bidirectional:
            # FIX: use the module logger instead of a bare print() for consistency.
            logger.info("Bidirectional is enabled, need to be uncausal and unpadded.")
        self.pooling_config = pooling_config

        super().__init__(**kwargs)  # flush inconsistent pretrained parameters
        # Copy class attributes to instance for serialization (save dict)
        self.model_type = self.__class__.model_type
        self.auto_map = self.__class__.auto_map
        self.architectures = self.__class__.architectures
        # Padding side configuration if not set
        if self.padding_side is None:
            self.padding_side = "right"
            # FIX: this is a warning, so emit it at WARNING level.
            logger.warning("padding_side is None, defaulting to 'right'.")
        # set recurrent configuration from operator mode
        if operator_mode not in self.__class__.RECURRENT_MODES:
            self.recurrent_config = parse_mode_name(operator_mode)
        else:
            self.recurrent_config = self.__class__.RECURRENT_MODES[operator_mode]
        logger.info("Using recurrent mode: %s", get_mode_name(**self.recurrent_config))
196
+
197
+
198
+ TpttConfig.register_for_auto_class()
199
+
200
+
201
def parse_mode_name(name: str) -> dict:
    """Translate an operator-mode string into its recurrent-config dict.

    Understands ``delta_product[_order][_gate][_gelu][_r|_c]`` and
    ``delta_rule[_gate][_gelu]``; raises ValueError for anything else.
    """
    if name.startswith("delta_product"):
        tokens = name.split("_")
        config = {"order": 2, "gate_type": "k", "linear": True, "trick": "derivative"}

        pos = 2  # skip the two prefix words 'delta' and 'product'
        # An explicit numeric order may follow the prefix (e.g. delta_product_3).
        if len(tokens) > pos and tokens[pos].isdigit():
            config["order"] = int(tokens[pos])
            pos += 1

        tail = tokens[pos:]
        # The trick suffix (r / c), when present, is always the last token.
        if tail and tail[-1] in ("r", "c"):
            config["trick"] = "rotative" if tail[-1] == "r" else "combined"
            tail = tail[:-1]
        # A 'gelu' marker (non-linear operator) sits just before the trick.
        if tail and tail[-1] == "gelu":
            config["linear"] = False
            tail = tail[:-1]
        # Anything left over names the gate type.
        if tail:
            config["gate_type"] = "_".join(tail)
        return config

    # delta_rule[_gate][_gelu]
    match = re.match(r"^delta_rule(?:_(kv|v|k))?(_gelu)?$", name)
    if match is not None:
        return {
            "order": 1,
            "gate_type": match.group(1) or "k",
            "linear": match.group(2) is None,
            "trick": "derivative",
        }
    raise ValueError(f"Unknown mode: {name}")
247
+
248
+
249
def get_mode_name(
    order: int = 1, gate_type: str = "k", linear: bool = True, trick: str = "derivative"
) -> str:
    """Build the canonical operator-mode string from recurrent-config values.

    Inverse of ``parse_mode_name`` for the supported configurations.
    """
    if order == 1:
        name = "delta_rule"
    elif order == 2:
        name = "delta_product"
    else:
        name = f"delta_product_{order}"

    suffixes = []
    if gate_type != "k":
        suffixes.append(gate_type)
    if not linear:
        suffixes.append("gelu")
    # The trick abbreviation only applies to second-order (and higher) modes.
    if order >= 2 and trick != "derivative":
        suffixes.append({"rotative": "r", "combined": "c"}.get(trick, trick))

    if suffixes:
        name = name + "_" + "_".join(suffixes)
    return name
266
+
267
+
268
def render_template(template_path: str, variables: dict) -> str:
    """Render the Jinja2 template at *template_path* with *variables*."""
    directory, filename = os.path.split(template_path)
    environment = Environment(loader=FileSystemLoader(directory))
    return environment.get_template(filename).render(**variables)
273
+
274
+
275
def write_model_card(output_path: str, content: str):
    """Persist *content* as README.md inside *output_path* (created if missing)."""
    os.makedirs(output_path, exist_ok=True)
    destination = os.path.join(output_path, "README.md")
    with open(destination, "w", encoding="utf-8") as readme:
        readme.write(content)
281
+
282
+
283
def generate_model_card(
    output_path: str,
    config: Union[dict, object],
    template: Optional[
        str
    ],  # can be "model_card" OR an absolute/relative path to a .md file
    extra_variables: Optional[Dict] = None,
):
    """
    Generate a README.md file from a Jinja2 template and a configuration.

    - template can be either:
        * a full path to a template file
        * a short name (e.g., "model_card") -> looked up in the package's
          ``templates`` directory; ``None`` selects the default card template.
    """
    if template is None:
        template = "model_card_template"  # default template name

    # Resolve the template: a direct path wins, otherwise use the bundled dir.
    if os.path.exists(template):
        template_path = template
    else:
        default_templates_dir = os.path.join(os.path.dirname(__file__), "templates")
        template_path = os.path.join(default_templates_dir, f"{template}.md")

    if not os.path.exists(template_path):
        raise FileNotFoundError(f"Template not found: {template_path}")

    context = {
        "model_id": os.path.basename(output_path),
        "config": config,
    }
    context.update(extra_variables or {})

    write_model_card(output_path, render_template(template_path, context))
lora_delta_product_m0.5_constant/README.md ADDED
@@ -0,0 +1,96 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ language: en
3
+ license: apache-2.0
4
+ library_name: transformers
5
+ tags:
6
+ - tptt
7
+ - peft
8
+ - trust_remote_code
9
+ pipeline_tag: text-generation
10
+ base_model: google/gemma-3-270m
11
+ datasets:
12
+ - yahma/alpaca-cleaned
13
+ ---
14
+
15
+ # lora_delta_product_m0.5_constant
16
+
17
+ <p align="center">
18
+ <a href="https://arxiv.org/abs/2506.17671">
19
+ <img alt="arXiv" src="https://img.shields.io/badge/arXiv-tptt-blueviolet.svg">
20
+ </a>
21
+ <a href="https://pypi.org/project/tptt/">
22
+ <img alt="PyPI" src="https://img.shields.io/pypi/v/tptt?color=orange">
23
+ </a>
24
+ <a href="https://github.com/fabienfrfr/tptt/">
25
+ <img alt="Release" src="https://img.shields.io/github/v/release/fabienfrfr/tptt?color=brightgreen">
26
+ </a>
27
+ <a href="https://fabienfrfr.github.io/tptt/">
28
+ <img alt="Documentation" src="https://img.shields.io/badge/docs-online-blue">
29
+ </a>
30
+ <a href="https://huggingface.co/ffurfaro">
31
+ <img alt="HuggingFace" src="https://img.shields.io/badge/hf-ffurfaro-yellow">
32
+ </a>
33
+ </p>
34
+
35
+ Titanesque version of `google/gemma-3-270m` with parallel linearized attention (TPTT 😊) and PEFT.
36
+
37
+ The architecture was presented in the paper [TPTT](https://huggingface.co/papers/2506.17671).
38
+
39
+
40
+ ## Model Details
41
+
42
+ - **Architecture:** ['TpttModel']
43
+ - **Base model:** google/gemma-3-270m
44
+ - **LiZA config:** operator=delta_product, mag=0.5
45
+ - **LoRA config:** r=8, alpha=16, dropout=0.05
46
+ - **torch_dtype:**
47
+
48
+ ## Usage
49
+
50
+
51
+ ```python
52
+ from transformers import AutoModelForCausalLM, AutoTokenizer
53
+
54
+ model = AutoModelForCausalLM.from_pretrained(
55
+ "ffurfaro/lora_delta_product_m0.5_constant",
56
+ trust_remote_code=True
57
+ )
58
+ tokenizer = AutoTokenizer.from_pretrained("google/gemma-3-270m")
59
+
60
+ prompt = "Your prompt here"
61
+ inputs = tokenizer(prompt, return_tensors="pt")
62
+ outputs = model.generate(**inputs, max_new_tokens=100)
63
+ print(tokenizer.decode(outputs[0], skip_special_tokens=True))
64
+
65
+ ```
66
+
67
+ > [!IMPORTANT]
68
+ > You must specify the `subfolder` if the repo contains multiple models, see the homepage for details.
69
+
70
+ ## Training
71
+
72
+ - **Dataset:** yahma/alpaca-cleaned
73
+ - **Platform:** Kaggle
74
+ - **Hardware:** 2xT4
75
+ - **Batch size:** 2
76
+ - **Epochs:** 1.0
77
+ - **Learning rate (final):** N/A
78
+ - **Loss (final):** 2.215506216105206
79
+ - **Training runtime:** 1318.9081 sec
80
+ - **Samples per second:** 1.962
81
+ - **Steps per second:** 0.491
82
+ - **Total FLOPs:** 200871790116864.0
83
+ - **Gradient norm (final):** N/A
84
+
85
+ ## Evaluation
86
+
87
+ - **Metrics:** Training loss only (no eval yet, table soon : PiQA, ARC, Hella, Wino, GSM8K, MMLU)
88
+ - **Results:** Final training loss: 2.215506216105206
89
+
90
+
91
+ ## Citation & Contact
92
+
93
+ If you use TPTT in your academic work, please cite [Furfaro](https://huggingface.co/ffurfaro). For questions or support, please open an issue on the [GitHub repository](https://github.com/fabienfrfr/tptt) or contact the maintainer.
94
+
95
+
96
+ ---
lora_delta_product_m0.5_constant/adapter_model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:296adb98a761c70cb9967520f44264a481d99a1ca96d42bc50329f8e1359b33d
3
+ size 2968848
lora_delta_product_m0.5_constant/added_tokens.json ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ {
2
+ "<image_soft_token>": 262144
3
+ }
lora_delta_product_m0.5_constant/config.json ADDED
@@ -0,0 +1,119 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_sliding_window_pattern": 6,
3
+ "architectures": [
4
+ "TpttModel"
5
+ ],
6
+ "attention_bias": false,
7
+ "attention_dropout": 0.0,
8
+ "attn_logit_softcapping": null,
9
+ "auto_map": {
10
+ "AutoConfig": "configuration_tptt.TpttConfig",
11
+ "AutoModelForCausalLM": "modeling_tptt.TpttModel"
12
+ },
13
+ "base_model_name": "google/gemma-3-270m",
14
+ "base_model_subfolder": null,
15
+ "base_scale_attn": null,
16
+ "bidirectional": false,
17
+ "cache_implementation": "hybrid",
18
+ "cross_gate": false,
19
+ "final_logit_softcapping": null,
20
+ "head_dim": 256,
21
+ "hidden_activation": "gelu_pytorch_tanh",
22
+ "hidden_size": 640,
23
+ "initializer_range": 0.02,
24
+ "intermediate_size": 2048,
25
+ "layer_types": [
26
+ "sliding_attention",
27
+ "sliding_attention",
28
+ "sliding_attention",
29
+ "sliding_attention",
30
+ "sliding_attention",
31
+ "full_attention",
32
+ "sliding_attention",
33
+ "sliding_attention",
34
+ "sliding_attention",
35
+ "sliding_attention",
36
+ "sliding_attention",
37
+ "full_attention",
38
+ "sliding_attention",
39
+ "sliding_attention",
40
+ "sliding_attention",
41
+ "sliding_attention",
42
+ "sliding_attention",
43
+ "full_attention"
44
+ ],
45
+ "linear_precision": "bfloat16",
46
+ "lora_config": {
47
+ "alpha_pattern": {},
48
+ "auto_mapping": null,
49
+ "base_model_name_or_path": null,
50
+ "bias": "none",
51
+ "corda_config": null,
52
+ "eva_config": null,
53
+ "exclude_modules": null,
54
+ "fan_in_fan_out": false,
55
+ "inference_mode": false,
56
+ "init_lora_weights": true,
57
+ "layer_replication": null,
58
+ "layers_pattern": null,
59
+ "layers_to_transform": null,
60
+ "loftq_config": {},
61
+ "lora_alpha": 16,
62
+ "lora_bias": false,
63
+ "lora_dropout": 0.05,
64
+ "megatron_config": null,
65
+ "megatron_core": "megatron.core",
66
+ "modules_to_save": null,
67
+ "peft_type": "LORA",
68
+ "r": 8,
69
+ "rank_pattern": {},
70
+ "revision": null,
71
+ "target_modules": [
72
+ "q_proj",
73
+ "v_proj",
74
+ "o_proj",
75
+ "k_proj"
76
+ ],
77
+ "task_type": "CAUSAL_LM",
78
+ "trainable_token_indices": null,
79
+ "use_dora": false,
80
+ "use_rslora": false
81
+ },
82
+ "mag_weight": 0.5,
83
+ "max_chunk_size": 64,
84
+ "max_position_embeddings": 32768,
85
+ "max_self_attn_length": null,
86
+ "model_task": "causal_lm",
87
+ "model_type": "tptt",
88
+ "num_attention_heads": 4,
89
+ "num_hidden_layers": 18,
90
+ "num_key_value_heads": 1,
91
+ "operator_mode": "delta_product",
92
+ "padding_side": "left",
93
+ "pooling_config": null,
94
+ "query_pre_attn_scalar": 256,
95
+ "recurrent_config": {
96
+ "gate_type": "k",
97
+ "linear": true,
98
+ "order": 2,
99
+ "trick": "derivative"
100
+ },
101
+ "rms_norm_eps": 1e-06,
102
+ "rope_local_base_freq": 10000.0,
103
+ "rope_scaling": null,
104
+ "rope_theta": 1000000.0,
105
+ "sliding_window": 512,
106
+ "sliding_window_pattern": 6,
107
+ "target_modules_names": [
108
+ "attn",
109
+ "self_attn",
110
+ "attention"
111
+ ],
112
+ "torch_dtype": "bfloat16",
113
+ "transformers_version": "4.52.4",
114
+ "trust_remote_code": true,
115
+ "use_bidirectional_attention": false,
116
+ "use_cache": true,
117
+ "use_linear_checkpoint": false,
118
+ "vocab_size": 262144
119
+ }
lora_delta_product_m0.5_constant/configuration_tptt.py ADDED
@@ -0,0 +1,318 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # pylint: disable=too-many-arguments, too-many-positional-arguments, too-many-instance-attributes, too-many-locals
2
+ """
3
+ Author : Fabien FURFARO
4
+ """
5
+
6
+ import logging
7
+ import os
8
+ import re
9
+ from typing import Any, Dict, List, Optional, Union
10
+ from jinja2 import Environment, FileSystemLoader
11
+
12
+ import psutil
13
+ import torch
14
+ from transformers import AutoConfig, PretrainedConfig
15
+
16
+ logger = logging.getLogger(__name__) # monitoring
17
+
18
+ # Constants
19
+ BYTES_IN_GB = 1024**3
20
+
21
+
22
def convert_sets_to_lists(obj):
    """Make a LoRA config JSON-serializable by replacing sets with lists.

    Dicts, lists and tuples are walked recursively; sets are converted
    shallowly (their members are hashable, so no nested dict/list inside).
    Any other value is returned unchanged.
    """
    if isinstance(obj, dict):
        return {key: convert_sets_to_lists(value) for key, value in obj.items()}
    if isinstance(obj, set):
        return list(obj)
    if isinstance(obj, (list, tuple)):
        return [convert_sets_to_lists(item) for item in obj]
    return obj
31
+
32
+
33
class TpttConfig(PretrainedConfig):
    """
    Configuration class for the TPTT model.

    Merges the backbone config (e.g., Llama/Gemma) with the custom TPTT
    parameters: linearized-attention operator mode, memory/attention gating
    weight (MaG), chunking, LoRA settings, etc.
    """

    model_type = "tptt"
    # Lets AutoConfig / AutoModelForCausalLM resolve the remote-code classes.
    auto_map = {
        "AutoModelForCausalLM": "modeling_tptt.TpttModel",
        "AutoConfig": "configuration_tptt.TpttConfig",
    }
    architectures = ["TpttModel"]

    RECURRENT_MODES = {
        "delta_rule": {
            "order": 1,
            "gate_type": "k",
            "linear": True,
            "trick": "derivative",
        },
        "delta_rule_v": {
            "order": 1,
            "gate_type": "v",
            "linear": True,
            "trick": "derivative",
        },
        "delta_rule_kv": {
            "order": 1,
            "gate_type": "kv",
            "linear": True,
            "trick": "derivative",
        },
        "delta_rule_gelu": {
            "order": 1,
            "gate_type": "k",
            "linear": False,
            "trick": "derivative",
        },
        "delta_product": {
            "order": 2,
            "gate_type": "k",
            "linear": True,
            "trick": "derivative",
        },
        "delta_product_r": {
            "order": 2,
            "gate_type": "k",
            "linear": True,
            "trick": "rotative",
        },
        "delta_product_c": {
            "order": 2,
            "gate_type": "k",
            "linear": True,
            "trick": "combined",
        },
    }  # Tested modes, see parse_mode_name if you want to add more

    def __init__(
        self,
        base_model_config: Optional[Union[dict, PretrainedConfig]] = None,
        base_model_name: str = "meta-llama/Llama-3.2-1B",
        base_model_subfolder: Optional[str] = None,
        name_or_path: Optional[str] = None,
        model_task: str = "causal_lm",
        target_modules_names: Optional[List[str]] = None,
        operator_mode: str = "delta_rule",
        use_linear_checkpoint: Optional[bool] = None,
        max_self_attn_length: Optional[
            int
        ] = None,  # unnecessary if SWA, else, standards 8192
        base_scale_attn: bool = False,
        mag_weight: float = 0.5,  # if 1.0, use only linear operator
        cross_gate: bool = False,  # unlinear mixing strategy
        max_chunk_size: int = 64,  # 128 if adaptive chunking (longest)
        linear_precision: Union[str, torch.dtype] = "float32",
        lora_config: Optional[dict] = None,  # only serialized accepted
        padding_side: Optional[str] = None,  # for tokenizer, default "right"
        bidirectional: bool = False,  # if True, use bidirectional attention
        pooling_config: Optional[Dict[str, Any]] = None,
        **kwargs,
    ):
        """Build the merged backbone + TPTT configuration.

        If ``base_model_config`` is given (dict or PretrainedConfig) it is
        merged field-by-field into this config; otherwise the backbone config
        is fetched from the Hub/local path named by ``base_model_name``.
        """
        if base_model_config is not None:
            if isinstance(base_model_config, PretrainedConfig):
                base_model_config = base_model_config.to_dict()
        else:
            # Load config from Hugging Face Hub or a local path
            base_model_config = AutoConfig.from_pretrained(
                base_model_name, **kwargs
            ).to_dict()
        # Merge all backbone fields into this config
        for k, v in base_model_config.items():
            setattr(self, k, v)

        self.base_model_name = base_model_name
        self.base_model_subfolder = base_model_subfolder
        self.model_task = model_task

        if name_or_path is not None:
            self._name_or_path = name_or_path
        else:
            # Derive a display name from the backbone repo id (drop the org).
            if "/" in base_model_name:
                self._name_or_path = "Titans-" + base_model_name.split("/", 1)[1]
            else:
                self._name_or_path = "Titans-" + base_model_name

        self.target_modules_names = target_modules_names or [
            "attn",
            "self_attn",
            "attention",
        ]
        self.operator_mode = operator_mode

        # Detect available memory on accelerator device
        if torch.cuda.is_available():
            _, total_mem = torch.cuda.mem_get_info()
        else:
            total_mem = psutil.virtual_memory().total
        total_mem_gb = total_mem / BYTES_IN_GB

        # Default to checkpointing the linear pass on small-memory hosts.
        self.use_linear_checkpoint = (
            total_mem_gb < 16
            if use_linear_checkpoint is None
            else use_linear_checkpoint
        )

        self.base_scale_attn = base_scale_attn
        self.mag_weight = mag_weight
        self.cross_gate = cross_gate
        self.max_chunk_size = max_chunk_size
        self.max_self_attn_length = max_self_attn_length
        if isinstance(linear_precision, torch.dtype):
            # Store the dtype as a plain string (e.g. "float32") so it serializes.
            linear_precision = str(linear_precision).replace("torch.", "")
        self.linear_precision = linear_precision

        self.lora_config = lora_config
        if lora_config is not None:
            # PEFT enums are not JSON-serializable; keep only their value.
            if hasattr(self.lora_config.get("peft_type"), "value"):
                self.lora_config["peft_type"] = self.lora_config["peft_type"].value
            self.lora_config = convert_sets_to_lists(self.lora_config)

        self.padding_side = padding_side
        self.bidirectional = bidirectional
        if self.bidirectional:
            # FIX: use the module logger instead of a bare print() for consistency.
            logger.info("Bidirectional is enabled, need to be uncausal and unpadded.")
        self.pooling_config = pooling_config

        super().__init__(**kwargs)  # flush inconsistent pretrained parameters
        # Copy class attributes to instance for serialization (save dict)
        self.model_type = self.__class__.model_type
        self.auto_map = self.__class__.auto_map
        self.architectures = self.__class__.architectures
        # Padding side configuration if not set
        if self.padding_side is None:
            self.padding_side = "right"
            # FIX: this is a warning, so emit it at WARNING level.
            logger.warning("padding_side is None, defaulting to 'right'.")
        # set recurrent configuration from operator mode
        if operator_mode not in self.__class__.RECURRENT_MODES:
            self.recurrent_config = parse_mode_name(operator_mode)
        else:
            self.recurrent_config = self.__class__.RECURRENT_MODES[operator_mode]
        logger.info("Using recurrent mode: %s", get_mode_name(**self.recurrent_config))
196
+
197
+
198
+ TpttConfig.register_for_auto_class()
199
+
200
+
201
def parse_mode_name(name: str) -> dict:
    """Translate an operator-mode string into its recurrent-config dict.

    Understands ``delta_product[_order][_gate][_gelu][_r|_c]`` and
    ``delta_rule[_gate][_gelu]``; raises ValueError for anything else.
    """
    if name.startswith("delta_product"):
        tokens = name.split("_")
        config = {"order": 2, "gate_type": "k", "linear": True, "trick": "derivative"}

        pos = 2  # skip the two prefix words 'delta' and 'product'
        # An explicit numeric order may follow the prefix (e.g. delta_product_3).
        if len(tokens) > pos and tokens[pos].isdigit():
            config["order"] = int(tokens[pos])
            pos += 1

        tail = tokens[pos:]
        # The trick suffix (r / c), when present, is always the last token.
        if tail and tail[-1] in ("r", "c"):
            config["trick"] = "rotative" if tail[-1] == "r" else "combined"
            tail = tail[:-1]
        # A 'gelu' marker (non-linear operator) sits just before the trick.
        if tail and tail[-1] == "gelu":
            config["linear"] = False
            tail = tail[:-1]
        # Anything left over names the gate type.
        if tail:
            config["gate_type"] = "_".join(tail)
        return config

    # delta_rule[_gate][_gelu]
    match = re.match(r"^delta_rule(?:_(kv|v|k))?(_gelu)?$", name)
    if match is not None:
        return {
            "order": 1,
            "gate_type": match.group(1) or "k",
            "linear": match.group(2) is None,
            "trick": "derivative",
        }
    raise ValueError(f"Unknown mode: {name}")
247
+
248
+
249
def get_mode_name(
    order: int = 1, gate_type: str = "k", linear: bool = True, trick: str = "derivative"
) -> str:
    """Build the canonical operator-mode string from recurrent-config values.

    Inverse of ``parse_mode_name`` for the supported configurations.
    """
    if order == 1:
        name = "delta_rule"
    elif order == 2:
        name = "delta_product"
    else:
        name = f"delta_product_{order}"

    suffixes = []
    if gate_type != "k":
        suffixes.append(gate_type)
    if not linear:
        suffixes.append("gelu")
    # The trick abbreviation only applies to second-order (and higher) modes.
    if order >= 2 and trick != "derivative":
        suffixes.append({"rotative": "r", "combined": "c"}.get(trick, trick))

    if suffixes:
        name = name + "_" + "_".join(suffixes)
    return name
266
+
267
+
268
def render_template(template_path: str, variables: dict) -> str:
    """Render the Jinja2 template at *template_path* with *variables*."""
    directory, filename = os.path.split(template_path)
    environment = Environment(loader=FileSystemLoader(directory))
    return environment.get_template(filename).render(**variables)
273
+
274
+
275
def write_model_card(output_path: str, content: str):
    """Persist *content* as README.md inside *output_path* (created if missing)."""
    os.makedirs(output_path, exist_ok=True)
    destination = os.path.join(output_path, "README.md")
    with open(destination, "w", encoding="utf-8") as readme:
        readme.write(content)
281
+
282
+
283
def generate_model_card(
    output_path: str,
    config: Union[dict, object],
    template: Optional[
        str
    ],  # can be "model_card" OR an absolute/relative path to a .md file
    extra_variables: Optional[Dict] = None,
):
    """
    Generate a README.md file from a Jinja2 template and a configuration.

    - template can be either:
        * a full path to a template file
        * a short name (e.g., "model_card") -> looked up in the package's
          ``templates`` directory; ``None`` selects the default card template.
    """
    if template is None:
        template = "model_card_template"  # default template name

    # Resolve the template: a direct path wins, otherwise use the bundled dir.
    if os.path.exists(template):
        template_path = template
    else:
        default_templates_dir = os.path.join(os.path.dirname(__file__), "templates")
        template_path = os.path.join(default_templates_dir, f"{template}.md")

    if not os.path.exists(template_path):
        raise FileNotFoundError(f"Template not found: {template_path}")

    context = {
        "model_id": os.path.basename(output_path),
        "config": config,
    }
    context.update(extra_variables or {})

    write_model_card(output_path, render_template(template_path, context))
lora_delta_product_m0.5_constant/modeling_tptt.py ADDED
@@ -0,0 +1,1509 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # pylint: disable=too-many-lines, too-many-arguments, too-many-positional-arguments, too-many-instance-attributes, too-many-locals
2
+
3
+ """
4
+ This module implements the TPTT model with linear attention (LiZA) and LoRA support.
5
+ Author : Fabien FURFARO
6
+ TPTT : Transforming Pretrained Transformers into Titans (https://arxiv.org/abs/2506.17671)
7
+ """
8
+
9
+ import logging
10
+ import math
11
+ import os
12
+ from pathlib import Path
13
+ import re
14
+ import shutil
15
+ from functools import partial
16
+ from typing import Any, Dict, List, Optional, Tuple, Union
17
+
18
+ import torch
19
+ import torch.nn.functional as F
20
+ from einops import rearrange
21
+ from huggingface_hub import hf_hub_download, list_repo_files
22
+ from peft import LoraConfig, PeftModel, get_peft_model
23
+ from safetensors import safe_open
24
+ from safetensors.torch import save_file
25
+ from torch import nn
26
+ from torch.utils.checkpoint import checkpoint
27
+ from transformers import (
28
+ AutoConfig,
29
+ AutoModel,
30
+ AutoModelForCausalLM,
31
+ DynamicCache,
32
+ PreTrainedModel,
33
+ )
34
+ from transformers.configuration_utils import PretrainedConfig
35
+
36
+ from .configuration_tptt import TpttConfig
37
+
38
+ logger = logging.getLogger(__name__) # monitoring
39
+
40
+
41
class LCache:
    """Cache for storing intermediate states of linear attention layers."""

    def __init__(self):
        """Map each layer index to its state dict (recurrent states, qkv buffers)."""
        self.inputs_states: Dict[int, Dict[str, torch.Tensor]] = {}

    def __getitem__(self, layer_idx: int) -> Optional[Dict[str, torch.Tensor]]:
        """Return the cached state for `layer_idx`, or None when absent."""
        return self.inputs_states.get(layer_idx)

    def update(self, layer_idx: int, **kwargs):
        """Store states for a layer, detaching tensors so no autograd graph is retained."""
        detached = {
            key: val.detach() if isinstance(val, torch.Tensor) else val
            for key, val in kwargs.items()
        }
        # Merge into the existing per-layer dict, creating it on first use.
        self.inputs_states.setdefault(layer_idx, {}).update(detached)

    def reset(self):
        """Drop every cached layer state."""
        self.inputs_states.clear()
69
+
70
+
71
class CausalAvgPool1d(nn.Module):
    """Causal sliding-window average (uniform weights, sequence length preserved)."""

    def __init__(
        self, output_size: int, offsets: tuple[int] = (0, 1, 2), mode: str = "replicate"
    ):
        super().__init__()
        self.offsets = offsets
        self.mode = mode
        self.pool = nn.AdaptiveAvgPool1d(output_size=output_size)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """x: [B, S, F] -> [B, S, output_size]"""
        feats = x.transpose(1, 2)  # [B, F, S]
        offs = torch.tensor(self.offsets, device=x.device)
        width = offs.max() - offs.min() + 1
        # Build a uniform averaging kernel over the requested offsets only.
        weights = torch.zeros(width, device=x.device, dtype=x.dtype)
        weights[offs - offs.min()] = 1 / len(self.offsets)
        n_ch = feats.shape[1]
        kernel = weights.repeat(n_ch, 1).reshape(n_ch, 1, width)
        # Pad so the convolution output keeps the original sequence length.
        left = -offs.min().item()
        right = (width - 1) - left
        padded = F.pad(feats, (left, right), mode=self.mode)
        averaged = F.conv1d(  # pylint: disable=not-callable
            padded, kernel, groups=n_ch
        )
        return self.pool(averaged.transpose(1, 2))  # [B, S, output_size]
95
+
96
+
97
class LinearAttention(nn.Module):
    """
    Linear multi-head attention layer: [B, S, D] -> [B, S, D].

    Optionally owns its own q/k/v/out projections (`shared_attn=False`),
    or consumes pre-projected q/k/v from a host attention module
    (`shared_attn=True`, the TPTT/LiZA path). The actual recurrence is
    delegated to `LinearAttentionOp` (defined elsewhere in this file).
    """

    def __init__(
        self,
        hidden_dim: int,
        num_heads: int,
        head_dim: Optional[int] = None,
        num_key_value_heads: Optional[int] = None,
        num_key_value_groups: Optional[int] = None,
        bias: bool = True,
        dropout: Optional[float] = None,
        linear_precision: torch.dtype = torch.float32,
        padding_side: str = "right",
        shared_attn: bool = False,  # True: q/k/v come pre-projected from the host attn
        layer_idx: int = 0,
        operator_mode: str = "delta_rule",
        use_linear_checkpoint: bool = False,
        recurrent_config: Optional[Dict[str, Any]] = None,
        linear_cache: Optional[LCache] = None,
        max_chunk_size: int = 64,
        bidirectional: bool = False,  # not used if causal
        pooling_config: Optional[Dict[str, Any]] = None,
    ):
        super().__init__()
        if pooling_config is None:
            pooling_config = {
                "offsets": (0, 1, 2),
                "mode": "replicate",
            }
        self.hidden_dim = hidden_dim
        self.num_heads = num_heads
        self.head_dim = head_dim or hidden_dim // num_heads
        self.num_key_value_heads = num_key_value_heads or num_heads
        self.num_key_value_groups = num_key_value_groups or (
            num_heads // (num_key_value_heads or num_heads)
        )
        self.scaling = self.head_dim**-0.5
        self.linear_precision = linear_precision
        self.padding_side = padding_side

        self.shared_attn = shared_attn

        # NOTE(review): projections (incl. out_proj) only exist when not shared;
        # in shared mode the caller must supply out_proj at forward time.
        if not shared_attn:
            self.q_proj = nn.Linear(hidden_dim, num_heads * self.head_dim, bias=bias)
            self.k_proj = nn.Linear(
                hidden_dim, self.num_key_value_heads * self.head_dim, bias=bias
            )
            self.v_proj = nn.Linear(
                hidden_dim, self.num_key_value_heads * self.head_dim, bias=bias
            )
            self.out_proj = nn.Linear(num_heads * self.head_dim, hidden_dim, bias=bias)

        self.dropout = nn.Dropout(dropout) if dropout is not None else None

        # Core recurrent operator (chunked linear attention, cached states).
        self.linear_operator = LinearAttentionOp(
            layer_idx=layer_idx,
            operator_mode=operator_mode,
            use_linear_checkpoint=use_linear_checkpoint,
            recurrent_config=recurrent_config,
            max_chunk_size=max_chunk_size,
            linear_cache=linear_cache,
            linear_precision=linear_precision,
        )
        self.bidirectional = bidirectional
        # Causal average pooling used to derive forget/write gates from k/v.
        self.pooling_config = pooling_config
        self.pool_g = CausalAvgPool1d(
            output_size=self.head_dim * self.num_key_value_heads, **pooling_config
        )

    def forward(
        self,
        x: Union[List[torch.Tensor], torch.Tensor],
        attn_mask: Optional[torch.Tensor] = None,
        out_proj: Optional[nn.Module] = None,
        **kwargs: Any,
    ) -> torch.Tensor:
        """
        Forward pass for linear attention.

        Args:
            x: Hidden states [B, S, D] (or a 1-tuple/list) when this module
                owns its projections; a [q, k, v] list in shared mode.
            attn_mask: Optional attention mask applied to v (padding-aware).
            out_proj: Output projection to use in shared mode (falls back to
                self.out_proj when available).

        Returns:
            Output tensor of shape [B, S, D].
        """

        if not self.shared_attn:
            hidden_states = x[0] if isinstance(x, (list, tuple)) else x
            # Projections
            q = self.q_proj(hidden_states)
            k = self.k_proj(hidden_states)
            v = self.v_proj(hidden_states)
            out_proj = self.out_proj
        else:
            # Shared attention <=> q/k/v already projected by the host module
            q, k, v = x[0], x[1], x[2]
            out_proj = self.out_proj if out_proj is None else out_proj

        # Remember the caller's dtype/device: the core runs in linear_precision.
        final_dtype, final_device = q.dtype, q.device
        # Zero out padded positions in v so they cannot leak into the state.
        if attn_mask is not None:
            v = apply_linear_attention_mask(attn_mask, v, self.padding_side)

        # Forget and Write gating for linear attn (terms used loosely):
        # causal pooled summaries of k and v.
        f_g, w_g = self.pool_g(k), self.pool_g(v)

        # Reshape for multi-head: [B, S, H*d] -> [B, H, S, d]
        q = rearrange(q, "b n (h d) -> b h n d", h=self.num_heads)
        k = rearrange(k, "b n (h d) -> b h n d", h=self.num_key_value_heads)
        v = rearrange(v, "b n (h d) -> b h n d", h=self.num_key_value_heads)

        f_g = rearrange(f_g, "b n (h m) -> b h n m", h=self.num_key_value_heads)
        w_g = rearrange(w_g, "b n (h m) -> b h n m", h=self.num_key_value_heads)

        # Repeat k/v (and gates) across query groups for GQA.
        k = k.repeat_interleave(self.num_key_value_groups, dim=1)
        v = v.repeat_interleave(self.num_key_value_groups, dim=1)

        f_g = f_g.repeat_interleave(self.num_key_value_groups, dim=1)
        w_g = w_g.repeat_interleave(self.num_key_value_groups, dim=1)

        ## DeltaNet-style feature map: SiLU activation then L2 normalization
        q = F.normalize(F.silu(q), p=2, dim=-1, eps=1e-6)
        k = F.normalize(F.silu(k), p=2, dim=-1, eps=1e-6)

        ## Numerical stability: scale and clamp values
        v = ensure_stability(v * self.scaling, min_val=-1e4, max_val=1e4)

        # Squash gates into (0, 1), clamped away from the saturation ends.
        f_g = torch.clamp(torch.sigmoid(f_g), min=1e-6, max=1 - 1e-6)
        w_g = torch.clamp(torch.sigmoid(w_g), min=1e-6, max=1 - 1e-6)

        # Run the recurrence in linear_precision (float32 by default).
        q, k, v, f_g, w_g = (
            x.to(self.linear_precision).contiguous() for x in (q, k, v, f_g, w_g)
        )
        g = (f_g, w_g)

        # Linear attention core, output: [B, H, S, d]
        if self.bidirectional:  # only meaningful for non-causal attention
            # Forward direction
            out_forward = self.linear_operator(q, k, v, g, **kwargs)
            # Backward direction: flip the sequence (dim=2); no cache reuse.
            kwargs_bwd = kwargs.copy()
            kwargs_bwd["use_cache"] = False
            out_backward = self.linear_operator(
                torch.flip(q, dims=[2]),
                torch.flip(k, dims=[2]),
                torch.flip(v, dims=[2]),
                tuple(torch.flip(t, dims=[2]) for t in g),
                **kwargs_bwd,
            )
            # Flip the output back to restore proper order
            out_backward = torch.flip(out_backward, dims=[2])
            # Fusion: here, simple addition
            out = out_forward + out_backward
        else:
            out = self.linear_operator(q, k, v, g, **kwargs)

        # Merge heads: [B, H, S, d] -> [B, S, H*d]
        out = rearrange(out, "b h s d -> b s (h d)")
        # RMS-normalize the output (bidirectional-compatible).
        out = out / out.pow(2).mean(dim=-1, keepdim=True).add(1e-6).sqrt()
        # Restore the caller's dtype and device.
        out = out.to(dtype=final_dtype, device=final_device)
        # Apply output projection
        out = out_proj(out)  # [B, S, D]
        out = ensure_stability(out, min_val=-1e4, max_val=1e4)
        # Apply dropout if specified
        if self.dropout is not None:
            out = self.dropout(out)
        return out
269
+
270
+
271
class LiZAttention(nn.Module):
    """
    LiZA attention module: wraps a pretrained ("base") attention layer and
    mixes its softmax output with a parallel linear-attention branch that
    shares the base layer's q/k/v/out projections (MAG strategy).
    """

    def __init__(
        self,
        base_attn: nn.Module,
        layer_idx: int,
        base_config: PretrainedConfig,  # backbone config
        linear_cache: Optional[LCache] = None,
        operator_mode: str = "delta_rule",
        use_linear_checkpoint: bool = False,
        recurrent_config: Optional[Dict[str, Any]] = None,
        max_self_attn_length: Optional[int] = None,  # unnecessary
        base_scale_attn: bool = False,
        mag_weight: float = 0.5,
        cross_gate: bool = False,
        max_chunk_size: int = 64,
        linear_precision: Union[str, torch.dtype] = "float32",
        padding_side: str = "right",  # for tokenizer
        disable_linear_attn: bool = False,
        bidirectional: bool = False,  # if True, use bidirectional attention
        pooling_config: Optional[Dict[str, Any]] = None,
    ):
        super().__init__()
        # Accept "float32" etc. as a string alias of the torch dtype.
        if isinstance(linear_precision, str):
            linear_precision = getattr(torch, linear_precision)
        self.linear_precision = linear_precision
        self.base_attn: nn.Module = base_attn
        self.base_config = base_config
        self.layer_idx = layer_idx
        self.max_self_attn_length = max_self_attn_length
        self.base_scale_attn = base_scale_attn
        self.mag_weight = mag_weight
        self.cross_gate = cross_gate
        self.max_chunk_size = max_chunk_size
        self.linear_precision = linear_precision  # NOTE(review): duplicate assignment
        self.padding_side = padding_side
        self.disable_linear_attn = disable_linear_attn

        (
            self.num_heads,
            self.head_dim,
            self.num_key_value_heads,
            self.num_key_value_groups,
            self.hidden_dim,
        ) = self._get_attention_parameters(base_attn, base_config)
        self.scaling = self.head_dim**-0.5

        # Linear branch in shared mode: it reuses the base layer's projections.
        self.linear_attn = LinearAttention(
            layer_idx=layer_idx,
            shared_attn=True,
            operator_mode=operator_mode,
            use_linear_checkpoint=use_linear_checkpoint,
            recurrent_config=recurrent_config,
            hidden_dim=self.hidden_dim,
            num_heads=self.num_heads,
            head_dim=self.head_dim,
            num_key_value_heads=self.num_key_value_heads,
            num_key_value_groups=self.num_key_value_groups,
            linear_precision=linear_precision,
            linear_cache=linear_cache,
            max_chunk_size=max_chunk_size,
            padding_side=padding_side,
            bidirectional=bidirectional,
            pooling_config=pooling_config,
        )

    def _get_attention_parameters(
        self, base_attn: nn.Module, base_config: PretrainedConfig
    ) -> Tuple[
        Optional[int], Optional[int], Optional[int], Optional[int], Optional[int]
    ]:
        """Retrieve (num_heads, head_dim, num_kv_heads, num_kv_groups, hidden_dim).

        Lookup order: the attention module first, then the backbone config,
        to cover the attribute-name variations across architectures.
        """
        num_heads = (
            getattr(base_attn, "num_heads", None)
            or getattr(base_attn, "num_q_heads", None)
            or getattr(base_config, "num_heads", None)
            or getattr(base_config, "num_attention_heads", None)
        )
        head_dim = (
            getattr(base_attn, "head_dim", None)
            or getattr(base_attn, "attention_head_size", None)
            or getattr(base_config, "head_dim", None)
            or (
                getattr(base_config, "hidden_size", None) // num_heads
                if num_heads and getattr(base_config, "hidden_size", None)
                else None
            )
        )
        num_key_value_heads = (
            getattr(base_attn, "num_kv_heads", None)
            or getattr(base_attn, "num_k_heads", None)
            or getattr(base_config, "num_key_value_heads", None)
            or num_heads  # fallback: MHA (no grouped KV)
        )
        num_key_value_groups = getattr(base_attn, "num_key_value_groups", None) or (
            num_heads // num_key_value_heads if num_heads and num_key_value_heads else 1
        )
        hidden_dim = getattr(base_config, "hidden_size", None) or head_dim * num_heads
        return (
            num_heads,
            head_dim,
            num_key_value_heads,
            num_key_value_groups,
            hidden_dim,
        )

    def _apply_shared_projections(
        self, hidden_states: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, nn.Module]:
        """Project hidden states to q/k/v with the base layer's own weights.

        Supports the q/k/v layouts of the common architectures; raises for
        anything else.
        """
        base_attn = self.base_attn
        if hasattr(base_attn, "q_proj"):
            # LLama, OLMO and Mistral style
            q = base_attn.q_proj(hidden_states)
            k = base_attn.k_proj(hidden_states)
            v = base_attn.v_proj(hidden_states)
            out_proj = base_attn.o_proj
        elif hasattr(base_attn, "qkv_proj"):
            # OpenELM and GPT-Neo style: fused QKV, split on the last dimension
            qkv = base_attn.qkv_proj(hidden_states)
            q, k, v = split_qkv(base_attn, qkv)
            out_proj = base_attn.out_proj
        elif hasattr(base_attn, "c_attn") and hasattr(base_attn, "c_proj"):
            # GPT-2 style
            qkv = base_attn.c_attn(hidden_states)
            q, k, v = qkv.chunk(3, dim=-1)
            out_proj = base_attn.c_proj
        elif all(hasattr(base_attn, n) for n in ["query", "key", "value"]):
            # BERT - ViT
            q = base_attn.query(hidden_states)
            k = base_attn.key(hidden_states)
            v = base_attn.value(hidden_states)
            out_proj = getattr(base_attn, "dense", None)  # or output.dense
        else:
            raise ValueError("Unsupported attention module: cannot find projections.")
        # Clamp projections for numerical stability.
        q = ensure_stability(q, min_val=-1e4, max_val=1e4)
        k = ensure_stability(k, min_val=-1e4, max_val=1e4)
        v = ensure_stability(v, min_val=-1e4, max_val=1e4)
        return q, k, v, out_proj

    def _process_self_attn(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor],
        kwargs,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[DynamicCache], int]:
        """Run the base (softmax) attention, optionally truncated.

        Returns (output, attn_weights, present_key_value, expected_attn_mode)
        where expected_attn_mode records how many outputs the base layer
        returned (1, 2 or 3) so forward() can mirror its convention.
        """
        if self.max_self_attn_length:  # not needed for SWA (context memorized non-parametrically)
            hidden_states, attention_mask = truncate_attention_mask(
                hidden_states, attention_mask, self.max_self_attn_length
            )

            # Keep only the rotary embeddings for the retained suffix.
            if kwargs.get("position_embeddings", None) is not None:
                cos, sin = kwargs["position_embeddings"]
                cos = cos[:, -self.max_self_attn_length :]
                sin = sin[:, -self.max_self_attn_length :]
                kwargs["position_embeddings"] = (cos, sin)

            if isinstance(kwargs.get("past_key_value", None), DynamicCache):
                # Crop the KV cache once (layer 0) to the truncated window.
                if (
                    len(kwargs["past_key_value"]) > self.layer_idx
                    and self.layer_idx == 0
                ):
                    kwargs["past_key_value"].crop(self.max_self_attn_length - 1)

        # Ensure attention mask is of the correct dtype and device
        if attention_mask is not None:
            attention_mask = attention_mask.to(
                dtype=hidden_states.dtype, device=hidden_states.device
            )
        # Standard attention (mask and rotation are applied inside)
        base_attn_outputs = self.base_attn(
            hidden_states,
            attention_mask=attention_mask,
            **kwargs,
        )

        # Normalize the various transformers return conventions.
        if isinstance(base_attn_outputs, tuple):
            if len(base_attn_outputs) == 3:
                o_base, attn_weights, present_key_value = base_attn_outputs
                expected_attn_mode = 3
            elif len(base_attn_outputs) == 2:
                o_base, attn_weights = base_attn_outputs
                present_key_value, expected_attn_mode = None, 2
            else:
                raise ValueError(
                    f"Unexpected number of outputs from base_attn: {len(base_attn_outputs)}"
                )
        else:
            o_base = base_attn_outputs
            attn_weights, present_key_value, expected_attn_mode = None, None, 1
        # Clamp for numerical stability.
        o_base = ensure_stability(o_base, min_val=-1e4, max_val=1e4)
        return o_base, attn_weights, present_key_value, expected_attn_mode

    def _prepare_attn_mixin(
        self,
        o_lin: torch.Tensor,
        o_base: torch.Tensor,
        tensor_dtype: torch.dtype,
        eps: float = 1e-5,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Cast both branches to a common dtype; optionally RMS-rescale the
        linear branch to the base branch's magnitude."""
        # Force cast typing, shape: [b n (h d)]
        o_lin = o_lin.to(tensor_dtype)
        o_base = o_base.to(tensor_dtype)
        # feature scaling
        if self.base_scale_attn:
            scaler = o_base.pow(2).mean(dim=-1, keepdim=True).add(eps).sqrt()
            o_lin = scaler * o_lin
        return o_lin, o_base

    def _apply_mag(
        self, linear_attention: torch.Tensor, softmax_attention: torch.Tensor
    ) -> torch.Tensor:
        """Blend the two branches with the MAG weight (plus optional cross term)."""
        # Left-padding management: align on the common (right-most) suffix.
        if linear_attention.shape[1] != softmax_attention.shape[1]:
            left_trunc = min(linear_attention.shape[1], softmax_attention.shape[1])
            linear_attention, softmax_attention = (
                linear_attention[:, -left_trunc:],
                softmax_attention[:, -left_trunc:],
            )
        # NAM: Neural Attention Mixer (with graph forcing)
        mag_weight = torch.tensor(
            self.mag_weight,
            dtype=softmax_attention.dtype,
            device=softmax_attention.device,
        )
        softmax_weighted = (1 - mag_weight) * softmax_attention
        linear_weighted = mag_weight * linear_attention
        if self.cross_gate:
            output_attention = (
                softmax_weighted + linear_weighted + softmax_weighted * linear_weighted
            )  # cross product term (non-linear interaction)
        else:
            output_attention = softmax_weighted + linear_weighted  # classic

        # Diagnostic: warn when the linear branch contributes ~nothing.
        if torch.allclose(softmax_weighted, output_attention):
            logger.info(
                "[LOG] layer : %s, softmax_weighted and output_attention are close.",
                self.layer_idx,
            )
        # Final output, clamped for stability.
        return ensure_stability(output_attention, min_val=-1e4, max_val=1e4)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> torch.Tensor:
        """Mix the linear and base self-attention forward passes."""
        device = hidden_states.device
        tensor_dtype = hidden_states.dtype
        self.base_attn.to(device)

        # No KV cache in training, nor in evaluation unless explicitly requested.
        if self.training:
            kwargs.pop("past_key_value", None)
            kwargs["use_cache"] = False
        elif "use_cache" not in kwargs:
            kwargs.pop("past_key_value", None)
            kwargs["use_cache"] = False

        kwargs.pop("position_ids", None)  # obsolete

        # Project once, with the base layer's own weights.
        q, k, v, out_proj = self._apply_shared_projections(hidden_states)

        # Linear-attention branch on the shared projections.
        o_lin = self.linear_attn(
            x=[q, k, v], attn_mask=attention_mask, out_proj=out_proj, **kwargs
        )

        # Base softmax branch (with optional truncation).
        o_base, attn_weights, present_key_value, expected_attn_mode = (
            self._process_self_attn(hidden_states, attention_mask, kwargs)
        )

        # Common dtype and optional rescaling before mixing.
        o_lin, o_base = self._prepare_attn_mixin(o_lin, o_base, tensor_dtype, eps=1e-5)

        # Memory-as-Gate mixing (skipped entirely under ablation).
        out = o_base if self.disable_linear_attn else self._apply_mag(o_lin, o_base)

        # Mirror the base layer's return convention.
        if expected_attn_mode == 3:
            return out, attn_weights, present_key_value
        if expected_attn_mode == 2:
            return out, attn_weights
        return out

    @property
    def is_sliding(self):
        """Whether the base attention uses sliding-window attention."""
        return getattr(self.base_attn, "is_sliding", False)
568
+
569
+
570
def load_tptt_safetensors(
    repo_or_path: str,
    model: Union[PreTrainedModel, PeftModel],
    subfolder: Optional[str] = None,
    token: Optional[str] = None,
) -> Union[PreTrainedModel, PeftModel]:
    """Load a Tptt safetensors adapter (LoRA/PEFT weights) into `model`.

    Args:
        repo_or_path: Local directory or Hugging Face Hub repo id.
        subfolder: Optional subfolder inside the repo/directory.
        token: Optional HF Hub auth token.

    Returns:
        The same `model`, with adapter weights loaded (unchanged when no
        adapter file is found).
    """
    # Sharding not supported yet (e.g.: -00001-of-00005.safetensors, ...)
    fname = "adapter_model.safetensors"
    # Subfolder management: only prepend when the path does not already end with it.
    if subfolder:
        repo_or_path_norm = os.path.normpath(repo_or_path)
        subfolder_norm = os.path.normpath(subfolder)
        if not repo_or_path_norm.endswith(subfolder_norm):
            fname = f"{subfolder}/{fname}" if subfolder else fname
    # Find file path (local directory first, then the Hub).
    if os.path.isdir(repo_or_path):
        path = os.path.join(repo_or_path, fname)
        if not os.path.exists(path):
            return model  # no adapter: nothing to load
    else:
        if fname not in list_repo_files(repo_or_path, token=token):
            return model  # no adapter in the remote repo
        path = hf_hub_download(repo_or_path, fname, token=token)

    # Load weights from safetensors
    with safe_open(path, framework="pt") as f:
        state_dict = {k: f.get_tensor(k) for k in f.keys()}

    # Adapt LoRA/specific keys if needed (add .default if expected by the model)
    def adapt_keys(sd, model):
        model_keys = list(model.state_dict().keys())
        # Detect the wrapper prefix the live model uses.
        if any(k.startswith("tptt_model.base_model.") for k in model_keys):
            prefix = "tptt_model.base_model."
        elif any(k.startswith("base_model.") for k in model_keys):
            prefix = "base_model."
        else:
            prefix = ""

        has_base_attn = any(".base_attn." in k for k in model_keys)

        def adapt_key(k):
            k_ = k if k.startswith(prefix) else prefix + k
            # First, align the LiZA wrapping: drop .base_attn. when the model
            # does not wrap its attention modules.
            if ".base_attn." in k_ and not has_base_attn:
                k_ = k_.replace(".base_attn.", ".")
            # Then map LoRA keys to the ".default" adapter naming if needed.
            if (
                k_.endswith("lora_A.weight") or k_.endswith("lora_B.weight")
            ) and k_.replace(".weight", ".default.weight") in model_keys:
                k_ = k_.replace(".weight", ".default.weight")
            return k_

        return {adapt_key(k): v for k, v in sd.items()}

    state_dict = adapt_keys(state_dict, model)

    # Cast tensors to the expected dtype of the model parameters
    model_state_dict = model.state_dict()
    for k, v in state_dict.items():
        if k in model_state_dict:
            expected_dtype = model_state_dict[k].dtype
            if v.dtype != expected_dtype:
                state_dict[k] = v.to(expected_dtype)

    logger.info("Input LoRA/Specific keys: %s", [k for k in state_dict.keys()])

    # Load into model (non-strict: only warn about missing LoRA keys).
    missing, unexpected = model.load_state_dict(state_dict, strict=False, assign=True)
    missing_lora = [k for k in missing if "lora" in k]
    if missing_lora:
        logger.warning("Missing keys: %s", missing_lora)
    if unexpected:
        logger.warning("Unexpected keys: %s", unexpected)
    return model
645
+
646
+
647
def get_tptt_model(  # pylint: disable=too-many-arguments, too-many-positional-arguments
    model: nn.Module,
    base_config: PretrainedConfig,  # or LlamaConfig, MistralConfig, etc.
    linear_cache: Optional[LCache] = None,
    liza_attention: nn.Module = LiZAttention,
    target_modules_names: Optional[list[str]] = None,
    operator_mode: str = "delta_rule",
    use_linear_checkpoint: bool = False,
    recurrent_config: Optional[Dict[str, Any]] = None,
    base_scale_attn: bool = False,
    mag_weight: float = 0.5,
    cross_gate: bool = False,
    max_chunk_size: int = 64,
    linear_precision: torch.dtype = torch.float32,
    max_self_attn_length: Optional[int] = None,  # unnecessary
    padding_side: str = "right",  # for tokenizer
    bidirectional: bool = False,  # if True, use bidirectional attention
    pooling_config: Optional[Dict[str, Any]] = None,
    **kwargs,  # quickfix unexpected arguments
) -> Tuple[PreTrainedModel, LCache]:
    """Replace the model's attention modules with LiZAttention wrappers.

    Returns the (mutated in place) model together with the shared LCache.
    """
    if target_modules_names is None:
        target_modules_names = ["attn", "self_attn", "attention"]
    # Find target modules by suffix (e.g., "attn", "attention") while
    # excluding submodules *inside* an already-matched attention module.
    target_modules_names = [
        name
        for name, _ in model.named_modules()
        if any(name.endswith(suffix) for suffix in target_modules_names)
        and not any(f".{suffix}." in name for suffix in target_modules_names)
    ]
    if not target_modules_names:
        raise ValueError(
            f"Target modules '{target_modules_names}' not found in the model."
        )
    # Shared cache across all injected layers.
    linear_cache = linear_cache or LCache()
    # Inject LiZAttention into the model
    for name, _ in model.named_modules():
        if name in target_modules_names:
            # Walk to the parent module so setattr replaces the child in place.
            parent = model
            *path, last = name.split(".")
            for p in path:
                parent = getattr(parent, p)
            layer_idx = extract_layer_idx(name)
            setattr(
                parent,
                last,
                liza_attention(
                    getattr(parent, last),
                    layer_idx=layer_idx,
                    base_config=base_config,
                    linear_cache=linear_cache,
                    operator_mode=operator_mode,
                    use_linear_checkpoint=use_linear_checkpoint,
                    recurrent_config=recurrent_config,
                    max_self_attn_length=max_self_attn_length,
                    base_scale_attn=base_scale_attn,
                    mag_weight=mag_weight,
                    cross_gate=cross_gate,
                    max_chunk_size=max_chunk_size,
                    linear_precision=linear_precision,
                    padding_side=padding_side,
                    bidirectional=bidirectional,
                    pooling_config=pooling_config,
                ),
            )
    return model, linear_cache
714
+
715
+
716
def save_tptt_safetensors(model, path: str, name: str = "adapter_model.safetensors"):
    """Save trainable LoRA/specific weights, adapting key names.

    Args:
        model: Model whose trainable (requires_grad) parameters are exported.
        path: Output directory; created if it does not exist.
        name: Filename of the safetensors archive written inside `path`.
    """
    # 1. Full state_dict (frozen weights included; filtered below).
    all_sd = model.state_dict()

    # 2. Trainable parameter names (usually only LoRA/PEFT adapters).
    trainable_keys = [
        param_name
        for param_name, param in model.named_parameters()
        if param.requires_grad
    ]  # You can also manually select specific keys in the model after load.

    # 3. Strip the custom wrapper prefixes so keys match a plain backbone.
    to_save = {
        k.replace("tptt_model.", "").replace("base_model.", ""): all_sd[k]
        for k in trainable_keys
    }

    # 4. Save the filtered adapters to a safetensors file.
    if to_save:
        # Bug fix: the file is written into `path`, so create `path` itself.
        # The previous os.makedirs(os.path.dirname(path)) created only the
        # parent and raised for bare directory names (dirname == "").
        os.makedirs(path, exist_ok=True)
        # sharding not supported yet (e.g.: -00001-of-00005.safetensors, ...)
        save_file(to_save, os.path.join(path, name))
737
+
738
+
739
+ class TpttModel(PreTrainedModel):
740
+ """
741
+ TPTT model wrapper with linear attention (LiZA) and LoRA support.
742
+ Handles only architecture and weights.
743
+ """
744
+
745
+ config_class = TpttConfig
746
+
747
    def __init__(
        self,
        config: TpttConfig,
        **kwargs,
    ):
        """
        Initialize TpttModel: load the backbone, inject LiZA attention
        modules, optionally attach LoRA, and load any saved adapter weights.
        """
        super().__init__(config, **kwargs)
        # Where to look for an existing adapter file (local path or repo id).
        repo_or_path = getattr(config, "_base_path", None) or config._name_or_path

        # 1. Load backbone (with subfolder management):
        kwargs_bb = kwargs.copy()
        if config.base_model_subfolder is not None:
            kwargs_bb["subfolder"] = config.base_model_subfolder
        else:
            kwargs_bb.pop("subfolder", None)

        if config.model_task == "causal_lm":
            tptt_model = AutoModelForCausalLM.from_pretrained(
                config.base_model_name, **kwargs_bb
            )
        else:
            tptt_model = AutoModel.from_pretrained(config.base_model_name, **kwargs_bb)

        # 2. Inject LiZA attention (all TPTT options come from the config).
        self.linear_cache = LCache()
        tptt_model, self.linear_cache = get_tptt_model(
            tptt_model, config, self.linear_cache, **config.to_dict()
        )

        # 3. Apply LoRA/specific adapters if configured.
        if config.lora_config is not None:
            lora_config_obj = LoraConfig(**config.lora_config)
            tptt_model = get_peft_model(tptt_model, lora_config_obj)
        else:
            # NOTE: doesn't work if quantization is applied!
            tptt_model = set_trainable_parameters(tptt_model)

        # 4. Load safetensors if a tptt/peft adapter exists in the repo.
        if repo_or_path:
            tptt_model = load_tptt_safetensors(
                repo_or_path,
                tptt_model,
                subfolder=kwargs.get("subfolder", None),
                token=kwargs.get("token", None),
            )
        self.tptt_model = tptt_model
796
+
797
+ def forward(
798
+ self,
799
+ input_ids: Optional[torch.LongTensor] = None,
800
+ attention_mask: Optional[torch.Tensor] = None,
801
+ labels: Optional[torch.LongTensor] = None,
802
+ **kwargs,
803
+ ):
804
+ """Forward pass. All arguments are passed to the underlying base model."""
805
+ if self.training:
806
+ kwargs["use_cache"] = False
807
+ kwargs.pop("num_items_in_batch", None)
808
+ elif "use_cache" not in kwargs: # evaluation
809
+ kwargs.pop("num_items_in_batch", None)
810
+ kwargs["use_cache"] = False
811
+ return self.tptt_model(
812
+ input_ids=input_ids, attention_mask=attention_mask, labels=labels, **kwargs
813
+ )
814
+
815
+ def generate(self, *args, **kwargs):
816
+ """Delegate the generate call to the backbone model, which supports generation"""
817
+ return self.tptt_model.generate(*args, **kwargs)
818
+
819
+ def save_pretrained(self, path: str, **kwargs):
820
+ """Save model weights, config, and source code to the given path."""
821
+ # 0. Save complete tptt config (with or without LoRA)
822
+ super().save_pretrained(path, **kwargs) # pylint: disable=no-member
823
+ self._adjust_save_strategy(path, **kwargs)
824
+ # 1. Save true weights and adapte keys
825
+ save_tptt_safetensors(self, path)
826
+ # 2. Copy Python files for trust_remote_code
827
+ self._copy_source_files(path, **kwargs)
828
+
829
+ def _adjust_save_strategy(self, path: str, **kwargs):
830
+ """Re-adapt/remove the weight safetensor and saved adapter config"""
831
+ if isinstance(self.tptt_model, PeftModel):
832
+ self.tptt_model.save_pretrained(path, **kwargs)
833
+ safetensor_path = os.path.join(path, "model.safetensors")
834
+ if os.path.exists(safetensor_path):
835
+ os.remove(safetensor_path)
836
+ adapter_path = os.path.join(path, "adapter_config.json")
837
+ if os.path.exists(adapter_path):
838
+ os.remove(adapter_path)
839
+
840
+ def _copy_source_files(self, target_path: str, **kwargs):
841
+ """Copy all .py files from package directory for trust_remote_code."""
842
+ src_dir = os.path.dirname(os.path.abspath(__file__))
843
+ dst_dir = (
844
+ f"./{str(Path(target_path).parts[0])}"
845
+ if kwargs.get("subfolder", False)
846
+ else target_path
847
+ )
848
+ for fname in os.listdir(src_dir):
849
+ if fname.endswith(".py"):
850
+ src = os.path.join(src_dir, fname)
851
+ dst = os.path.join(dst_dir, fname)
852
+ shutil.copy2(src, dst)
853
+
854
+ def retie_lm_after_load(self, **kwargs):
855
+ """Re-link lm_head after loading external weights."""
856
+ embed_lm = find_embedding_lm(self.tptt_model)
857
+ if embed_lm is not None and hasattr(self.tptt_model, "lm_head"):
858
+ if self.tptt_model.lm_head is None: # ensure lm_head exists
859
+ self.tptt_model.lm_head = nn.Linear(
860
+ embed_lm.weight.shape[1], embed_lm.weight.shape[0], bias=False
861
+ )
862
+ if kwargs.get("tie_word_embeddings", True):
863
+ self.tptt_model.lm_head.weight = embed_lm.weight # share weights
864
+ logger.info("Weights of lm_head have been shared with embedding.")
865
+ else:
866
+ self.tptt_model.lm_head.weight = nn.Parameter(embed_lm.weight.clone())
867
+ logger.info("Weights of lm_head have been cloned from the embedding.")
868
+
869
+ @classmethod
870
+ def from_pretrained(cls, pretrained_model_name_or_path=None, *model_args, **kwargs):
871
+ """Custom from_pretrained that accepts the standard positional argument"""
872
+ config = kwargs.pop("config", None)
873
+ repo_or_path = (
874
+ pretrained_model_name_or_path
875
+ or kwargs.pop("pretrained_model_name_or_path", None)
876
+ or kwargs.pop("repo_or_path", None)
877
+ or (getattr(config, "_base_path", None) if config else None)
878
+ or (getattr(config, "_name_or_path", None) if config else None)
879
+ )
880
+
881
+ if config is None and repo_or_path is not None:
882
+ config = AutoConfig.from_pretrained(repo_or_path, **kwargs)
883
+ model = cls(config, *model_args, **kwargs)
884
+ model.retie_lm_after_load(**kwargs)
885
+ return model
886
+
887
+
888
+ TpttModel.register_for_auto_class("AutoModelForCausalLM")
889
+
890
+
891
class LinearAttentionOp(nn.Module):
    """Base class for linear attention operators.

    Wraps the chunkwise delta-rule / delta-product recurrence, the gating
    computation, and per-layer recurrent cache handling.
    """

    def __init__(
        self,
        layer_idx: int,
        operator_mode: str = "delta_rule",
        use_linear_checkpoint: bool = False,
        recurrent_config: Optional[dict] = None,
        max_chunk_size: int = 64,
        linear_cache: Optional[LCache] = None,
        linear_precision: torch.dtype = torch.float32,
    ):
        """
        Args:
            layer_idx: Index of the wrapped layer; used as the cache key.
            operator_mode: Recurrence flavor; forced to "delta_rule" when no
                recurrent_config is provided.
            use_linear_checkpoint: Gradient-checkpoint each chunk (training only).
            recurrent_config: Dict with keys "order", "gate_type", "linear", "trick".
            max_chunk_size: Upper bound on the chunkwise scan length.
            linear_cache: Shared LCache (a fresh one is created when None).
            linear_precision: Dtype for the numerically sensitive recurrence.
        """
        super().__init__()
        self.layer_idx = layer_idx
        if recurrent_config is None:
            operator_mode = "delta_rule"  # force default operator mode if no config
            recurrent_config = {
                "order": 1,
                "gate_type": "k",
                "linear": True,
                "trick": "derivative",
            }
        self.operator_mode = operator_mode
        self.use_linear_checkpoint = use_linear_checkpoint

        # Recurrence hyper-parameters (see chunk_delta_product_forward).
        self.order = recurrent_config["order"]
        self.gate_type = recurrent_config["gate_type"]
        self.linear = recurrent_config["linear"]
        self.trick = recurrent_config["trick"]

        self.max_chunk_size = max_chunk_size
        self.linear_cache = linear_cache or LCache()
        self.linear_precision = linear_precision

    def compute_gate(self, beta: Tuple[torch.Tensor]) -> torch.Tensor:
        """
        Compute the gating tensor according to the gate_type.

        beta is a (key gate, value gate) pair; results are clamped away from
        0 and 1 for numerical stability.
        """
        if self.gate_type == "k":
            return torch.clamp(beta[0], min=1e-6, max=1 - 1e-6)
        if self.gate_type == "v":
            return torch.clamp(beta[1], min=1e-6, max=1 - 1e-6)
        if self.gate_type == "kv":
            return torch.clamp(beta[0] * beta[1], min=1e-6, max=1 - 1e-6)
        raise ValueError(f"Unsupported gate_type: {self.gate_type}")

    def get_cache(self, use_cache: bool) -> Tuple[
        Optional[torch.Tensor],
        Optional[Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]],
    ]:
        """
        Retrieve recurrent state and qkv buffers from the cache.

        Returns (None, None) when caching is disabled or the layer has no entry.
        """
        if not use_cache:
            return None, None
        last_state = self.linear_cache[self.layer_idx]
        if last_state is not None:
            recurrent_state = last_state.get("recurrent_state", None)
            qkv_buffers = last_state.get("qkv", None)
        else:
            recurrent_state = None
            qkv_buffers = None
        return recurrent_state, qkv_buffers

    def save_cache(
        self,
        use_cache: bool,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        gate: torch.Tensor,
        state: torch.Tensor,
    ) -> None:
        """
        Save the recurrent state and qkv buffers to the cache.

        For order > 1, the last (order - 1) positions of q/k/v/gate are buffered
        so the next call can rebuild its virtual-token context.
        """
        if not use_cache:
            return
        if self.order > 1:
            qkv_buffers = (
                q[:, :, -(self.order - 1) :, :],
                k[:, :, -(self.order - 1) :, :],
                v[:, :, -(self.order - 1) :, :],
                gate[:, :, -(self.order - 1) :, :],
            )
        else:
            qkv_buffers = None
        self.linear_cache.update(self.layer_idx, recurrent_state=state, qkv=qkv_buffers)

    def forward(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        beta: Union[Tuple[torch.Tensor], torch.Tensor],
        **kwargs,
    ) -> torch.Tensor:
        """
        Forward pass for the attention operator.

        q/k/v are indexed as [batch, heads, seq, head_dim] (dim=2 is the
        sequence axis used for cache concatenation). All inputs are promoted
        to ``linear_precision`` first.
        """
        # Ensure linear_precision for numerical stability (float32)
        q, k, v = [x.to(self.linear_precision) for x in (q, k, v)]
        if isinstance(beta, (tuple, list)):
            beta = tuple(b.to(self.linear_precision) for b in beta)
        else:
            beta = beta.to(self.linear_precision)

        gate = self.compute_gate(beta)

        # Retrieve cache if needed
        use_cache = kwargs.get("use_cache", False)
        # Checkpointing only applies to cache-free (training-style) passes.
        use_checkpoint = not (use_cache) and self.use_linear_checkpoint
        recurrent_state, qkvb = self.get_cache(use_cache)

        # NOTE(review): buffered tokens are prepended only when their shape
        # equals q's shape — confirm this guard matches the intended
        # order > 1 streaming path.
        if qkvb is not None and qkvb[0].shape == q.shape:
            q = torch.cat([qkvb[0].to(q.device), q], dim=2).to(self.linear_precision)
            k = torch.cat([qkvb[1].to(q.device), k], dim=2).to(self.linear_precision)
            v = torch.cat([qkvb[2].to(q.device), v], dim=2).to(self.linear_precision)
            gate = torch.cat([qkvb[3].to(q.device), gate], dim=2).to(
                self.linear_precision
            )

        output, state = self.chunk_delta_product_forward(
            q,
            k,
            v,
            gate,
            self.max_chunk_size,
            n=self.order,
            trick=self.trick,
            linear=self.linear,
            initial_state=recurrent_state,
            use_checkpoint=use_checkpoint,
            linear_precision=self.linear_precision,
        )

        # Save cache if needed
        self.save_cache(use_cache, q, k, v, gate, state)

        return output

    @staticmethod
    def chunk_delta_product_forward(
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        beta_gate: torch.Tensor,
        chunk_size: int,
        n: int = 1,
        trick: str = "derivative",
        linear: bool = True,
        initial_state: Optional[torch.Tensor] = None,
        use_checkpoint: bool = True,
        linear_precision: torch.dtype = torch.float32,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Chunkwise parallel implementation https://arxiv.org/abs/2406.06484
        For each chunk, processes chunk_size * n_orders steps (virtual tokens) in order.

        Returns:
            (output, final_state): output is [B, H, S, D] in linear_precision;
            final_state is the recurrent state carried to the next call.
        """

        # --- Main chunk_delta_product_forward logic ---

        batch_size, num_heads, seq_len, head_dim = query.shape
        # Shrink chunk_size to a divisor of seq_len so chunks tile exactly.
        chunk_size = get_valid_chunk_size(seq_len, chunk_size)
        num_chunks = seq_len // chunk_size

        # Expand each token into n "virtual" tokens (identity when n == 1).
        query_n = query if n == 1 else expand_virtual_tokens(query, n, trick)
        key_n = key if n == 1 else expand_virtual_tokens(key, n, trick)
        value_n = value if n == 1 else expand_virtual_tokens(value, n, trick)
        beta_n = beta_gate if n == 1 else expand_virtual_tokens(beta_gate, n, trick)

        q_chunks = chunk_sequence(query_n, num_chunks, chunk_size * n)
        k_chunks = chunk_sequence(key_n, num_chunks, chunk_size * n)
        v_chunks = chunk_sequence(value_n, num_chunks, chunk_size * n)
        beta_chunks = chunk_sequence(beta_n, num_chunks, chunk_size * n)

        k_beta = k_chunks * beta_chunks
        v_beta = v_chunks * beta_chunks

        # Strictly lower-triangular intra-chunk interaction matrix.
        householder = -(k_beta @ k_chunks.transpose(-2, -1)).tril(-1)
        householder = ensure_stability(householder, min_val=-1e4, max_val=1e4)

        # size : N = chunk_size * n
        inv_hh = fast_invert_matrix(householder, dtype=linear_precision)  # [(...),N,N]

        w = ensure_stability(torch.matmul(inv_hh, k_beta), min_val=-1e4, max_val=1e4)
        u = ensure_stability(torch.matmul(inv_hh, v_beta), min_val=-1e4, max_val=1e4)

        state_shape = (batch_size, num_heads, n, head_dim, head_dim)
        if initial_state is not None and initial_state.shape == state_shape:
            state = initial_state.to(device=query.device, dtype=linear_precision)
        else:
            state = torch.full(
                state_shape,
                fill_value=1e-6,  # stability if unlinear activation
                device=query.device,
                dtype=linear_precision,
            )

        output, final_state = sequential_delta_product_scan(
            q_chunks.to(dtype=linear_precision),
            w.to(dtype=linear_precision),
            u.to(dtype=linear_precision),
            n,
            linear,
            chunk_size,
            state.to(dtype=linear_precision),
            linear_precision=linear_precision,
            use_checkpoint=use_checkpoint,
        )

        # Keep only the last virtual token of each real token.
        idx_last_order = torch.arange(chunk_size, device=output.device) * n + (n - 1)
        output = output[:, :, :, idx_last_order, :]  # [B, H, num_chunks, chunk_size, D]
        output = output.reshape(batch_size, num_heads, seq_len, head_dim)

        return output.to(dtype=linear_precision), final_state.to(dtype=linear_precision)
1108
+
1109
+
1110
def sequential_delta_product_scan(
    q_chunks: torch.Tensor,
    w: torch.Tensor,
    u: torch.Tensor,
    n_orders: int,
    linear_activation: bool,
    current_chunk_size: int,
    initial_recurrent_state: torch.Tensor,
    linear_precision: torch.dtype,
    use_checkpoint: bool,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    DeltaProduct implementation https://arxiv.org/abs/2502.10297
    Implements the per-token Householder state updates.

    Args:
        q_chunks / w / u: Chunked tensors [B, H, num_chunks, chunk_size * n, D].
        n_orders: Number of virtual tokens per real token.
        linear_activation: When False, a non-linear activation is applied to
            the carried state between chunks.
        current_chunk_size: Number of real tokens per chunk.
        initial_recurrent_state: Carried state, indexed [B, H, n, D, D].
        linear_precision: Dtype of all intermediate computation.
        use_checkpoint: Gradient-checkpoint each chunk to save memory.

    Returns:
        (output, last_state) with output shaped like q_chunks.
    """
    batch, head, num_chunks_inner, chunk_n_total, dim = q_chunks.shape
    output_inner = torch.empty_like(q_chunks)
    # initial_recurrent_state is H_{last_token_of_prev_chunk, n-1} ([B, H, D, D])
    h_0_base = initial_recurrent_state[:, :, -1, :, :].clone()

    def process_one_chunk(
        q_chunk_params: torch.Tensor,
        w_chunk_params: torch.Tensor,
        u_chunk_params: torch.Tensor,
        h_0_base: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Process a single chunk (with per-token state for n_orders > 1).

        Returns the intra-chunk output, the inter-chunk output, and the state
        carried to the next chunk.
        """
        o_intra_current_chunk = torch.zeros(
            batch,
            head,
            chunk_n_total,
            dim,
            device=q_chunk_params.device,
            dtype=linear_precision,
        )
        o_inter_current_chunk = torch.zeros_like(o_intra_current_chunk)
        # Every real token starts the chunk from the same carried state.
        current_accumulated_state_per_token = (
            h_0_base.unsqueeze(2).expand(-1, -1, current_chunk_size, -1, -1).clone()
        )  # [B, H, current_chunk_size, D, D]

        for step in range(n_orders):
            # Indices of the step-th virtual token of every real token.
            idx_virtual_tokens = (
                torch.arange(current_chunk_size, device=q_chunk_params.device)
                * n_orders
                + step
            )
            q_s = q_chunk_params[:, :, idx_virtual_tokens, :]
            w_s = w_chunk_params[:, :, idx_virtual_tokens, :]
            u_s = u_chunk_params[:, :, idx_virtual_tokens, :]

            state_input_for_this_step = current_accumulated_state_per_token

            ## BLAS/cuBLAS einsum "bhcd,bhcdd->bhcd"
            k_trans_h_old = (
                torch.matmul(
                    w_s.unsqueeze(-2),
                    state_input_for_this_step,
                )
                .squeeze(-2)
                .to(dtype=linear_precision)
            )

            u_val = u_s - k_trans_h_old

            o_inter_current_chunk[:, :, idx_virtual_tokens, :] = (
                torch.matmul(q_s.unsqueeze(-2), state_input_for_this_step)
                .squeeze(-2)
                .to(dtype=linear_precision)
            )

            ## BLAS/cuBLAS einsum "bhcd,bhcd->bhcd"
            o_intra_current_chunk[:, :, idx_virtual_tokens, :] = (q_s * u_val).to(
                dtype=linear_precision
            )

            # Rank-1 (Householder-style) state update per token.
            outer_product_term = torch.matmul(w_s.unsqueeze(-1), u_val.unsqueeze(-2))
            new_state_i_per_token = state_input_for_this_step + outer_product_term
            current_accumulated_state_per_token = new_state_i_per_token.to(
                dtype=linear_precision
            )
        # Return all needed for next chunk
        return (
            o_intra_current_chunk,
            o_inter_current_chunk,
            current_accumulated_state_per_token[:, :, -1, :, :],  # new h_0_base
        )

    for chunk_idx_inner in range(num_chunks_inner):
        q_chunk_params = q_chunks[:, :, chunk_idx_inner]
        w_chunk_params = w[:, :, chunk_idx_inner]
        u_chunk_params = u[:, :, chunk_idx_inner]

        # Checkpointed call if training
        call = (
            partial(checkpoint, use_reentrant=False)
            if use_checkpoint
            else lambda f, *a: f(*a)
        )
        o_intra, o_inter, h_0_base = call(
            process_one_chunk,
            q_chunk_params,
            w_chunk_params,
            u_chunk_params,
            h_0_base,
        )
        if not linear_activation:  # unlinear activation between chunks
            h_0_base = unlinear_activation(h_0_base).to(dtype=linear_precision)
        output_inner[:, :, chunk_idx_inner] = o_intra + o_inter

    return output_inner, h_0_base
1222
+
1223
+
1224
def unlinear_activation(x: torch.Tensor, scale: float = 2.0) -> torch.Tensor:
    """Norm-stabilized GELU non-linearity applied between chunks.

    The input is normalized by its L2 norm along the last axis before the
    tanh-approximated GELU, then used to modulate ``x / scale``.
    """
    denom = x.norm(p=2, dim=-1, keepdim=True) + 1e-6  # guard against zero norm
    gate = F.gelu(scale * x / denom, approximate="tanh")  # pylint: disable=not-callable
    return (x / scale) * gate
1229
+
1230
+
1231
def chunk_sequence(x: torch.Tensor, num_chunks: int, chunk_size: int) -> torch.Tensor:
    """Reshape [B, H, S, D] into [B, H, num_chunks, chunk_size, D].

    Requires S == num_chunks * chunk_size.
    """
    batch, heads, _, depth = x.shape
    return x.reshape(batch, heads, num_chunks, chunk_size, depth)
1235
+
1236
+
1237
def expand_virtual_tokens(
    x: torch.Tensor, n: int, mode: str = "derivative"
) -> torch.Tensor:
    """Expand tokens into 'n' virtual tokens using the selected trick.

    Layout: [B, H, S, D] -> [B, H, S * n, D]. Modes: "derivative"
    (finite-difference combinations over a sliding window of n tokens),
    "rotative" (n planar rotations of each token), or "combined" (their mean).
    """
    batch_size, num_heads, seq_len, head_dim = x.shape
    device, dtype = x.device, x.dtype

    def derivative_expand(x: torch.Tensor) -> torch.Tensor:
        """Expand tokens using the derivative trick."""
        # Left-pad with n-1 zero tokens so every position has a full window.
        x_pad = torch.cat(
            [
                torch.zeros(
                    batch_size, num_heads, n - 1, head_dim, device=device, dtype=dtype
                ),
                x,
            ],
            dim=2,
        )
        # Alternating binomial coefficients of the (n-1)-th finite difference,
        # L1-normalized so the combination stays bounded.
        coeffs = torch.tensor(
            [(-1) ** k * math.comb(n - 1, k) for k in range(n)],
            device=device,
            dtype=dtype,
        )
        coeffs /= coeffs.norm(p=1)
        return (
            (x_pad.unfold(2, n, 1) * coeffs.view(1, 1, 1, 1, n))
            .flip(-1)
            .permute(0, 1, 2, 4, 3)
            .reshape(batch_size, num_heads, seq_len * n, head_dim)
        )

    def rotative_expand(x: torch.Tensor) -> torch.Tensor:
        """Expand tokens using the rotative trick."""
        d_parity = head_dim // 2
        # n evenly spaced rotation angles over the full circle.
        angles = torch.arange(n, device=device, dtype=dtype) * (2 * math.pi / n)
        cos = torch.cos(angles).view(1, 1, 1, n, 1)
        sin = torch.sin(angles).view(1, 1, 1, n, 1)
        # Odd head_dim: rotate channel pairs, keep the last channel unrotated.
        if head_dim % 2:
            x_pairs = x[..., :-1].view(batch_size, num_heads, seq_len, d_parity, 2)
        else:
            x_pairs = x.view(batch_size, num_heads, seq_len, d_parity, 2)
        x_pairs = x_pairs.unsqueeze(3).expand(
            batch_size, num_heads, seq_len, n, d_parity, 2
        )
        x0, x1 = x_pairs[..., 0], x_pairs[..., 1]
        # 2D rotation applied to each (even, odd) channel pair.
        x0r = x0 * cos - x1 * sin
        x1r = x0 * sin + x1 * cos
        rot = torch.stack([x0r, x1r], -1).reshape(
            batch_size, num_heads, seq_len, n, d_parity * 2
        )
        if head_dim % 2:
            # Re-append the unrotated trailing channel.
            last = (
                x[..., -1]
                .unsqueeze(-1)
                .unsqueeze(3)
                .expand(batch_size, num_heads, seq_len, n, 1)
            )
            rot = torch.cat([rot, last], -1)
        return rot.reshape(batch_size, num_heads, seq_len * n, head_dim)

    if mode == "derivative":
        return derivative_expand(x)
    if mode == "rotative":
        return rotative_expand(x)
    if mode == "combined":
        return (derivative_expand(x) + rotative_expand(x)) / 2
    raise ValueError(f"Unknown mode: {mode}")
1304
+
1305
+
1306
def extract_layer_idx(module_name: str) -> int:
    """Return the layer index embedded in a dotted module name, or -1.

    Matches the first ``.<digits>.`` segment, e.g.
    ``"model.layers.7.self_attn"`` -> 7.
    """
    found = re.search(r"\.(\d+)\.", module_name)
    return int(found.group(1)) if found else -1
1312
+
1313
+
1314
def find_embedding_lm(module: nn.Module) -> Optional[nn.Module]:
    """Return the token-embedding submodule of *module*, or None.

    Recognizes the common attribute names ``embed_tokens`` and
    ``token_embeddings`` on any submodule; the attribute must carry a
    ``weight`` to qualify.
    """
    for _, sub in module.named_modules():
        for attr in ("embed_tokens", "token_embeddings"):
            candidate = getattr(sub, attr, None)
            if candidate is not None and hasattr(candidate, "weight"):
                return candidate
    return None
1324
+
1325
+
1326
def set_trainable_parameters(
    model: PreTrainedModel, trainable_patterns: Optional[List[str]] = None
) -> PreTrainedModel:
    """Freeze all model parameters except those matching *trainable_patterns*.

    Args:
        model: Model whose parameters are (un)frozen in place.
        trainable_patterns: Substrings identifying trainable parameter names.
            Defaults to the usual attention projection names.
            (Annotation fixed: the default of None makes this Optional.)

    Returns:
        The same model instance, modified in place.
    """
    if trainable_patterns is None:
        # Default: only the attention projections stay trainable.
        trainable_patterns = [
            "q_proj",
            "k_proj",
            "v_proj",
            "o_proj",
            "qkv_proj",
            "out_proj",
            "c_attn",
            "c_proj",
            "query",
            "key",
            "value",
        ]

    for name, param in model.named_parameters():
        param.requires_grad = any(pattern in name for pattern in trainable_patterns)

    trainable_layers = [n for n, p in model.named_parameters() if p.requires_grad]
    logger.info("Trainable parameters after freeze: %s", trainable_layers)
    return model
1351
+
1352
+
1353
def ensure_stability(
    tensor: torch.Tensor, min_val: float = -1e4, max_val: float = 1e4
) -> torch.Tensor:
    """Clamp to [min_val, max_val] and replace NaN/inf with finite values.

    NaNs map to the interval midpoint, +/-inf to the respective bound;
    the input dtype is preserved.
    """
    original_dtype = tensor.dtype
    midpoint = (max_val + min_val) / 2
    bounded = torch.clamp(tensor, min=min_val, max=max_val)
    cleaned = torch.nan_to_num(bounded, nan=midpoint, posinf=max_val, neginf=min_val)
    return cleaned.to(dtype=original_dtype)
1362
+
1363
+
1364
def apply_linear_attention_mask(
    attention_mask: torch.Tensor, v: torch.Tensor, padding_side: str = "right"
) -> torch.Tensor:
    """Zero out padded positions of *v* using *attention_mask*.

    Accepts 2D [B, S] masks, masks with squeezable singleton dims, or 4D
    [B, 1, S, S] additive masks (their diagonal is used). Additive masks
    (0 / -inf style) are converted to a binary keep mask.
    """
    # Reduce the mask to [B, S].
    if attention_mask.dim() == 4 and attention_mask.shape[1] == 1:
        # 4D additive mask: take the per-position diagonal.
        mask = attention_mask.diagonal(dim1=-2, dim2=-1).squeeze(1)
    else:
        singleton_dims = tuple(
            i
            for i in range(1, attention_mask.dim())
            if attention_mask.shape[i] == 1
        )
        mask = attention_mask.squeeze(dim=singleton_dims)
    # Normalize to a binary keep mask in v's dtype.
    already_binary = mask.dtype == torch.bool or (
        mask.dtype in [torch.uint8, torch.int32, torch.int64]
        and mask.max() <= 1
        and mask.min() >= 0
    )
    if already_binary:
        mask = mask.to(v.dtype)
    else:
        mask = (mask >= 0).to(v.dtype)  # [-inf, 0, 0, -inf] --> [0, 1, 1, 0]
    # Align mask [batch, seq] with v [batch, seq, (...)] and broadcast.
    if padding_side == "left":
        mask = mask[:, -v.shape[-2] :][(...,) + (None,) * (v.dim() - 2)]
    else:  # right padding
        mask = mask[:, : v.shape[-2]][(...,) + (None,) * (v.dim() - 2)]
    return v * mask
1396
+
1397
+
1398
def truncate_attention_mask(
    hidden_states: torch.Tensor, attention_mask: torch.Tensor, max_length: int
) -> tuple[torch.Tensor, torch.Tensor]:
    """Keep only the last *max_length* positions of states and mask.

    No-op when the sequence already fits. Supports [B, S], [B, S, S] and
    [B, 1, S, S] masks; any other non-None shape raises ValueError.
    """
    seq_dim = 1  # convention: (batch, seq, ...)
    seq_len = hidden_states.shape[seq_dim]
    if seq_len <= max_length:
        return hidden_states, attention_mask
    hidden_states = hidden_states.narrow(seq_dim, seq_len - max_length, max_length)
    if attention_mask is None:
        return hidden_states, attention_mask
    if attention_mask.dim() == 2:
        # mask [batch, seq]
        attention_mask = attention_mask[:, -max_length:]
    elif attention_mask.dim() == 3:
        # mask [batch, seq, seq]
        attention_mask = attention_mask[:, -max_length:, -max_length:]
    elif attention_mask.dim() == 4 and attention_mask.shape[1] == 1:
        # mask [batch, 1, seq, seq]
        attention_mask = attention_mask[:, :, -max_length:, -max_length:]
    else:
        raise ValueError(
            "No dimension in attention_mask matches sequence length of hidden_states."
        )
    return hidden_states, attention_mask
1421
+
1422
+
1423
def fast_invert_matrix(
    tri_tensor: torch.Tensor, dtype: torch.dtype = torch.float32
) -> torch.Tensor:
    """Invert (I - T) for a strictly lower-triangular T by forward substitution.

    Vectorized row-by-row accumulation equivalent to forward substitution
    applied to the identity matrix; the result is returned in *dtype*.
    """
    work = tri_tensor.to(dtype=dtype).clone()
    size = work.shape[-1]

    # Each row folds in the contributions of the already-processed rows above it.
    for row in range(1, size):
        update = (
            work[..., row, :, None].clone() * work[..., :, :row].clone()
        ).sum(-2)
        work[..., row, :row] = work[..., row, :row] + update

    identity = torch.eye(size, dtype=dtype, device=work.device)
    return (work + identity).to(dtype=dtype)
1439
+
1440
+
1441
def get_valid_chunk_size(total_l: int, chunk_size: int) -> int:
    """Return the largest divisor of *total_l* that does not exceed *chunk_size*.

    Always at least 1 (1 divides everything; the trailing return is a
    defensive fallback).
    """
    upper = min(chunk_size, total_l)
    for candidate in range(upper, 0, -1):
        if total_l % candidate == 0:
            return candidate
    return 1
1447
+
1448
+
1449
+ ## RARELY
1450
def split_qkv(
    base_attn: nn.Module, qkv: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Split a fused QKV projection output into separate Q, K and V tensors.

    Args:
        base_attn: Attention module exposing ``num_q_heads``, ``num_k_heads``,
            ``num_v_heads`` and ``head_dim``.
        qkv: Fused tensor whose last dim is (num_q + num_k + num_v) * head_dim.

    Returns:
        Tuple (q, k, v), split along the last dimension.

    Raises:
        ValueError: If any required head attribute is missing or None.
    """
    num_q_heads = getattr(base_attn, "num_q_heads", None)
    num_k_heads = getattr(base_attn, "num_k_heads", None)
    num_v_heads = getattr(base_attn, "num_v_heads", None)
    head_dim = getattr(base_attn, "head_dim", None)

    # BUGFIX: head_dim must be validated too — previously a missing head_dim
    # surfaced as an opaque TypeError in the multiplications below.
    if None in (num_q_heads, num_k_heads, num_v_heads, head_dim):
        raise ValueError(
            "Base attention must have num_q_heads, num_k_heads, and num_v_heads defined."
        )

    q_len = num_q_heads * head_dim
    k_len = num_k_heads * head_dim
    v_len = num_v_heads * head_dim

    q, k, v = torch.split(qkv, [q_len, k_len, v_len], dim=-1)
    return q, k, v
1470
+
1471
+
1472
+ ## OPTIONAL
1473
def match_dim(x: torch.Tensor, dim: int, target_size: int) -> torch.Tensor:
    """Resize tensor *x* along dimension *dim* to *target_size*.

    Upsampling uses 1D linear interpolation; downsampling keeps the first
    *target_size* components via a rectangular identity projection. The
    input is returned unchanged when sizes already match.
    """
    src_size = x.shape[dim]
    if src_size == target_size:
        return x
    moved = torch.moveaxis(x, dim, -1)
    lead_shape = moved.shape
    if src_size < target_size:
        # Upsample: flatten leading dims and interpolate the last axis.
        flat = moved.reshape(-1, 1, src_size)
        flat = F.interpolate(flat, size=target_size, mode="linear", align_corners=False)
        moved = flat.reshape(*lead_shape[:-1], target_size)
    else:
        # Downsample: project onto the leading target_size coordinates.
        projection = torch.eye(
            target_size, src_size, device=moved.device, dtype=moved.dtype
        )
        moved = F.linear(moved, projection)  # pylint: disable=not-callable
    return torch.moveaxis(moved, -1, dim)
1489
+
1490
+
1491
def soft_clamp(
    x: torch.Tensor, min_val: float = 1e-6, max_val: float = 1 - 1e-6
) -> torch.Tensor:
    """Differentiable clamp of *x* into (min_val, max_val) via tanh squashing.

    The interval midpoint maps to itself; values far outside saturate toward
    the bounds. The input dtype is preserved.
    """
    original_dtype = x.dtype
    half_range = (max_val - min_val) / 2
    midpoint = (max_val + min_val) / 2
    squashed = torch.tanh((x - midpoint) / half_range) * half_range + midpoint
    return squashed.to(dtype=original_dtype)
1499
+
1500
+
1501
def describe(x: torch.Tensor, name="tensor") -> None:
    """Print a one-line summary of *x*: shape, min/max/mean/std, dtype, device."""
    minimum, maximum, mean, std = x.min(), x.max(), x.mean(), x.std()
    summary = (
        f"{name} shape: {tuple(x.shape)}, "
        f"min: {minimum:.4g}, max: {maximum:.4g}, "
        f"mean: {mean:.4g}, std: {std:.4g}, "
        f"dtype: {x.dtype}, device: {x.device}"
    )
    print(summary)
lora_delta_product_m0.5_constant/runs/Aug29_08-00-53_7f724a8e0ba4/events.out.tfevents.1756454454.7f724a8e0ba4.159.0 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:0e0162a52dec47246a768fc16693c78a663b93f956e43fee88682b984951c236
3
+ size 69639
lora_delta_product_m0.5_constant/special_tokens_map.json ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "boi_token": "<start_of_image>",
3
+ "bos_token": {
4
+ "content": "<bos>",
5
+ "lstrip": false,
6
+ "normalized": false,
7
+ "rstrip": false,
8
+ "single_word": false
9
+ },
10
+ "eoi_token": "<end_of_image>",
11
+ "eos_token": {
12
+ "content": "<eos>",
13
+ "lstrip": false,
14
+ "normalized": false,
15
+ "rstrip": false,
16
+ "single_word": false
17
+ },
18
+ "image_token": "<image_soft_token>",
19
+ "pad_token": {
20
+ "content": "<pad>",
21
+ "lstrip": false,
22
+ "normalized": false,
23
+ "rstrip": false,
24
+ "single_word": false
25
+ },
26
+ "unk_token": {
27
+ "content": "<unk>",
28
+ "lstrip": false,
29
+ "normalized": false,
30
+ "rstrip": false,
31
+ "single_word": false
32
+ }
33
+ }
lora_delta_product_m0.5_constant/tokenizer.json ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:6303ee46fcfa60a90e110d2007271cac11ee3c916f187a23fc6b54122a7f1f45
3
+ size 33384820
lora_delta_product_m0.5_constant/tokenizer.model ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:1299c11d7cf632ef3b4e11937501358ada021bbdf7c47638d13c0ee982f2e79c
3
+ size 4689074
lora_delta_product_m0.5_constant/tokenizer_config.json ADDED
The diff for this file is too large to render. See raw diff
 
modeling_tptt.py ADDED
@@ -0,0 +1,1509 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # pylint: disable=too-many-lines, too-many-arguments, too-many-positional-arguments, too-many-instance-attributes, too-many-locals
2
+
3
+ """
4
+ This module implements the TPTT model with linear attention (LiZA) and LoRA support.
5
+ Author : Fabien FURFARO
6
+ TPTT : Transforming Pretrained Transformers into Titans (https://arxiv.org/abs/2506.17671)
7
+ """
8
+
9
+ import logging
10
+ import math
11
+ import os
12
+ from pathlib import Path
13
+ import re
14
+ import shutil
15
+ from functools import partial
16
+ from typing import Any, Dict, List, Optional, Tuple, Union
17
+
18
+ import torch
19
+ import torch.nn.functional as F
20
+ from einops import rearrange
21
+ from huggingface_hub import hf_hub_download, list_repo_files
22
+ from peft import LoraConfig, PeftModel, get_peft_model
23
+ from safetensors import safe_open
24
+ from safetensors.torch import save_file
25
+ from torch import nn
26
+ from torch.utils.checkpoint import checkpoint
27
+ from transformers import (
28
+ AutoConfig,
29
+ AutoModel,
30
+ AutoModelForCausalLM,
31
+ DynamicCache,
32
+ PreTrainedModel,
33
+ )
34
+ from transformers.configuration_utils import PretrainedConfig
35
+
36
+ from .configuration_tptt import TpttConfig
37
+
38
+ logger = logging.getLogger(__name__) # monitoring
39
+
40
+
41
class LCache:
    """Per-layer cache of intermediate linear-attention states."""

    def __init__(self):
        """Maps layer_idx -> dict of named tensors (recurrent states, qkv buffers)."""
        self.inputs_states: Dict[int, Dict[str, torch.Tensor]] = {}

    def __getitem__(self, layer_idx: int) -> Optional[Dict[str, torch.Tensor]]:
        """Return the cached state dict for ``layer_idx``, or ``None`` when absent."""
        return self.inputs_states.get(layer_idx)

    def update(self, layer_idx: int, **kwargs):
        """Store (or merge) named tensors for a layer.

        Tensors are detached first so the cache never retains autograd graphs.
        """
        detached = {
            key: value.detach() if isinstance(value, torch.Tensor) else value
            for key, value in kwargs.items()
        }
        # Merge into the existing per-layer dict, creating it on first use.
        self.inputs_states.setdefault(layer_idx, {}).update(detached)

    def reset(self):
        """Drop every cached layer state."""
        self.inputs_states.clear()
69
+
70
+
71
class CausalAvgPool1d(nn.Module):
    """Uniform sliding-window average over the sequence (length preserved),
    followed by adaptive average pooling of the feature dimension."""

    def __init__(
        self, output_size: int, offsets: tuple[int] = (0, 1, 2), mode: str = "replicate"
    ):
        super().__init__()
        self.offsets = offsets
        self.mode = mode
        self.pool = nn.AdaptiveAvgPool1d(output_size=output_size)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """x: [B, S, F] -> [B, S, output_size]"""
        seq_first = x.transpose(1, 2)  # [B, F, S]
        n_feat = seq_first.shape[1]
        offs = torch.tensor(self.offsets, device=x.device)
        lo, hi = offs.min(), offs.max()
        width = hi - lo + 1
        # Uniform taps at the requested offsets, zero elsewhere.
        taps = torch.zeros(width, device=x.device, dtype=x.dtype)
        taps[offs - lo] = 1 / len(self.offsets)
        # One depthwise kernel per feature channel (grouped convolution).
        kernel = taps.repeat(n_feat, 1).reshape(n_feat, 1, width)
        left = -lo.item()
        right = (width - 1) - left
        padded = F.pad(seq_first, (left, right), mode=self.mode)
        smoothed = F.conv1d(  # pylint: disable=not-callable
            padded, kernel, groups=n_feat
        )
        return self.pool(smoothed.transpose(1, 2))  # [B, S, output_size]
95
+
96
+
97
class LinearAttention(nn.Module):
    """
    Linear multi-head attention layer: [B, S, D] -> [B, S, D]
    Projections + gating + efficient linear attention mechanism (TPTT compatible).

    In ``shared_attn`` mode no projections are owned here: the caller supplies
    pre-projected q/k/v (and usually an ``out_proj``), so the layer reuses the
    backbone's weights.
    """

    def __init__(
        self,
        hidden_dim: int,
        num_heads: int,
        head_dim: Optional[int] = None,
        num_key_value_heads: Optional[int] = None,
        num_key_value_groups: Optional[int] = None,
        bias: bool = True,
        dropout: Optional[float] = None,
        linear_precision: torch.dtype = torch.float32,
        padding_side: str = "right",
        shared_attn: bool = False,  # shared attention
        layer_idx: int = 0,
        operator_mode: str = "delta_rule",
        use_linear_checkpoint: bool = False,
        recurrent_config: Optional[Dict[str, Any]] = None,
        linear_cache: Optional[LCache] = None,
        max_chunk_size: int = 64,
        bidirectional: bool = False,  # not used if causal
        pooling_config: Optional[Dict[str, Any]] = None,
    ):
        """Build projections (unless shared), the recurrent linear operator,
        and the causal pooling used to derive the forget/write gates."""
        super().__init__()
        if pooling_config is None:
            pooling_config = {
                "offsets": (0, 1, 2),
                "mode": "replicate",
            }
        self.hidden_dim = hidden_dim
        self.num_heads = num_heads
        # Derive head_dim / kv-head counts from hidden_dim when not given.
        self.head_dim = head_dim or hidden_dim // num_heads
        self.num_key_value_heads = num_key_value_heads or num_heads
        self.num_key_value_groups = num_key_value_groups or (
            num_heads // (num_key_value_heads or num_heads)
        )
        self.scaling = self.head_dim**-0.5
        self.linear_precision = linear_precision
        self.padding_side = padding_side

        self.shared_attn = shared_attn

        if not shared_attn:
            # Standalone mode: own q/k/v/out projections (GQA-aware sizes).
            self.q_proj = nn.Linear(hidden_dim, num_heads * self.head_dim, bias=bias)
            self.k_proj = nn.Linear(
                hidden_dim, self.num_key_value_heads * self.head_dim, bias=bias
            )
            self.v_proj = nn.Linear(
                hidden_dim, self.num_key_value_heads * self.head_dim, bias=bias
            )
            self.out_proj = nn.Linear(num_heads * self.head_dim, hidden_dim, bias=bias)

        self.dropout = nn.Dropout(dropout) if dropout is not None else None

        # Recurrent linear-attention kernel (delta rule, chunked, cached).
        self.linear_operator = LinearAttentionOp(
            layer_idx=layer_idx,
            operator_mode=operator_mode,
            use_linear_checkpoint=use_linear_checkpoint,
            recurrent_config=recurrent_config,
            max_chunk_size=max_chunk_size,
            linear_cache=linear_cache,
            linear_precision=linear_precision,
        )
        self.bidirectional = bidirectional
        # Causal average pooling for gating
        self.pooling_config = pooling_config
        self.pool_g = CausalAvgPool1d(
            output_size=self.head_dim * self.num_key_value_heads, **pooling_config
        )

    def forward(
        self,
        x: Union[List[torch.Tensor], torch.Tensor],
        attn_mask: Optional[torch.Tensor] = None,
        out_proj: Optional[nn.Module] = None,
        **kwargs: Any,
    ) -> torch.Tensor:
        """
        Forward pass for linear attention. Input shape: [B, S, D], output [B, S, D].

        In shared mode ``x`` is the [q, k, v] triple (already projected) and
        ``out_proj`` should be supplied by the caller; otherwise ``x`` is the
        hidden states and this layer's own projections are used.
        """

        if not self.shared_attn:
            hidden_states = x[0] if isinstance(x, (list, tuple)) else x
            # Projections
            q = self.q_proj(hidden_states)
            k = self.k_proj(hidden_states)
            v = self.v_proj(hidden_states)
            out_proj = self.out_proj
        else:
            # Shared attention <=> no projections here
            # NOTE(review): if out_proj is None here, self.out_proj does not
            # exist in shared mode — callers are expected to pass one.
            q, k, v = x[0], x[1], x[2]
            out_proj = self.out_proj if out_proj is None else out_proj

        # get dtype and device
        final_dtype, final_device = q.dtype, q.device
        # Masking if needed
        if attn_mask is not None:
            v = apply_linear_attention_mask(attn_mask, v, self.padding_side)

        # Forget and Write Gating for linear attn (abusive term)
        f_g, w_g = self.pool_g(k), self.pool_g(v)

        # Reshape for multi-head
        q = rearrange(q, "b n (h d) -> b h n d", h=self.num_heads)
        k = rearrange(k, "b n (h d) -> b h n d", h=self.num_key_value_heads)
        v = rearrange(v, "b n (h d) -> b h n d", h=self.num_key_value_heads)

        f_g = rearrange(f_g, "b n (h m) -> b h n m", h=self.num_key_value_heads)
        w_g = rearrange(w_g, "b n (h m) -> b h n m", h=self.num_key_value_heads)

        # Repeat for GQA
        k = k.repeat_interleave(self.num_key_value_groups, dim=1)
        v = v.repeat_interleave(self.num_key_value_groups, dim=1)

        f_g = f_g.repeat_interleave(self.num_key_value_groups, dim=1)
        w_g = w_g.repeat_interleave(self.num_key_value_groups, dim=1)

        ## DeltaNet-style: Silu activation and normalization
        q = F.normalize(F.silu(q), p=2, dim=-1, eps=1e-6)
        k = F.normalize(F.silu(k), p=2, dim=-1, eps=1e-6)

        ## linear stability part
        v = ensure_stability(v * self.scaling, min_val=-1e4, max_val=1e4)

        # Apply sigmoid to forget and write gates
        f_g = torch.clamp(torch.sigmoid(f_g), min=1e-6, max=1 - 1e-6)
        w_g = torch.clamp(torch.sigmoid(w_g), min=1e-6, max=1 - 1e-6)

        # Convert to linear_precision (float32) for numerical stability and get model dtype
        q, k, v, f_g, w_g = (
            x.to(self.linear_precision).contiguous() for x in (q, k, v, f_g, w_g)
        )
        g = (f_g, w_g)

        # Linear Attention Core, output: [B, H, S, d]
        if self.bidirectional:  # Work only with uncausal attention
            # Forward direction
            out_forward = self.linear_operator(q, k, v, g, **kwargs)
            # Backward direction: flip the input sequence on the time dimension (dim=2)
            kwargs_bwd = kwargs.copy()
            kwargs_bwd["use_cache"] = False
            out_backward = self.linear_operator(
                torch.flip(q, dims=[2]),
                torch.flip(k, dims=[2]),
                torch.flip(v, dims=[2]),
                tuple(torch.flip(t, dims=[2]) for t in g),
                **kwargs_bwd,
            )
            # Flip the output back to restore proper order
            out_backward = torch.flip(out_backward, dims=[2])
            # Fusion: here, simple addition
            out = out_forward + out_backward
        else:
            out = self.linear_operator(q, k, v, g, **kwargs)

        # Merge heads and project: [B, H, S, d] -> [B, S, H*d] -> Out proj
        out = rearrange(out, "b h s d -> b s (h d)")
        # Normalize output (RMS norm). Note: bidirectional compatibility
        out = out / out.pow(2).mean(dim=-1, keepdim=True).add(1e-6).sqrt()
        # Ensure dtype and device consistency
        out = out.to(dtype=final_dtype, device=final_device)
        # Apply output projection
        out = out_proj(out)  # [B, S, D]
        out = ensure_stability(out, min_val=-1e4, max_val=1e4)
        # Apply dropout if specified
        if self.dropout is not None:
            out = self.dropout(out)
        return out
269
+
270
+
271
class LiZAttention(nn.Module):
    """LiZA Linear Attention module, mixing linear and vanilla attention.

    Wraps a pretrained attention module (``base_attn``) and runs a parallel
    :class:`LinearAttention` branch on the same q/k/v projections, then mixes
    both outputs with the MAG strategy weighted by ``mag_weight``.
    """

    def __init__(
        self,
        base_attn: nn.Module,
        layer_idx: int,
        base_config: PretrainedConfig,  # Backbone Config
        linear_cache: Optional[LCache] = None,
        operator_mode: str = "delta_rule",
        use_linear_checkpoint: bool = False,
        recurrent_config: Optional[Dict[str, Any]] = None,
        max_self_attn_length: Optional[int] = None,  # unnecessary
        base_scale_attn: bool = False,
        mag_weight: float = 0.5,
        cross_gate: bool = False,
        max_chunk_size: int = 64,
        linear_precision: Union[str, torch.dtype] = "float32",
        padding_side: str = "right",  # for tokenizer
        disable_linear_attn: bool = False,
        bidirectional: bool = False,  # if True, use bidirectional attention
        pooling_config: Optional[Dict[str, Any]] = None,
    ):
        """Wrap ``base_attn`` and build the shared-projection linear branch."""
        super().__init__()
        # Accept precision given either as a string ("float32") or a torch.dtype.
        if isinstance(linear_precision, str):
            linear_precision = getattr(torch, linear_precision)
        self.linear_precision = linear_precision
        self.base_attn: nn.Module = base_attn
        self.base_config = base_config
        self.layer_idx = layer_idx
        self.max_self_attn_length = max_self_attn_length
        self.base_scale_attn = base_scale_attn
        self.mag_weight = mag_weight
        self.cross_gate = cross_gate
        self.max_chunk_size = max_chunk_size
        self.linear_precision = linear_precision
        self.padding_side = padding_side
        self.disable_linear_attn = disable_linear_attn

        # Attention geometry inferred from the wrapped module / backbone config.
        (
            self.num_heads,
            self.head_dim,
            self.num_key_value_heads,
            self.num_key_value_groups,
            self.hidden_dim,
        ) = self._get_attention_parameters(base_attn, base_config)
        self.scaling = self.head_dim**-0.5

        # Linear branch in shared mode: it reuses base_attn's projections.
        self.linear_attn = LinearAttention(
            layer_idx=layer_idx,
            shared_attn=True,
            operator_mode=operator_mode,
            use_linear_checkpoint=use_linear_checkpoint,
            recurrent_config=recurrent_config,
            hidden_dim=self.hidden_dim,
            num_heads=self.num_heads,
            head_dim=self.head_dim,
            num_key_value_heads=self.num_key_value_heads,
            num_key_value_groups=self.num_key_value_groups,
            linear_precision=linear_precision,
            linear_cache=linear_cache,
            max_chunk_size=max_chunk_size,
            padding_side=padding_side,
            bidirectional=bidirectional,
            pooling_config=pooling_config,
        )

    def _get_attention_parameters(
        self, base_attn: nn.Module, base_config: PretrainedConfig
    ) -> Tuple[
        Optional[int], Optional[int], Optional[int], Optional[int], Optional[int]
    ]:
        """Retrieve the attention parameters from the base attention module.

        Returns (num_heads, head_dim, num_key_value_heads, num_key_value_groups,
        hidden_dim); each probe prefers the module's own attribute, then the
        config, then a derived fallback.
        """
        # first order base attention module and second order config
        num_heads = (
            getattr(base_attn, "num_heads", None)
            or getattr(base_attn, "num_q_heads", None)
            or getattr(base_config, "num_heads", None)
            or getattr(base_config, "num_attention_heads", None)
        )
        head_dim = (
            getattr(base_attn, "head_dim", None)
            or getattr(base_attn, "attention_head_size", None)
            or getattr(base_config, "head_dim", None)
            or (
                getattr(base_config, "hidden_size", None) // num_heads
                if num_heads and getattr(base_config, "hidden_size", None)
                else None
            )
        )
        num_key_value_heads = (
            getattr(base_attn, "num_kv_heads", None)
            or getattr(base_attn, "num_k_heads", None)
            or getattr(base_config, "num_key_value_heads", None)
            or num_heads  # fallback
        )
        num_key_value_groups = getattr(base_attn, "num_key_value_groups", None) or (
            num_heads // num_key_value_heads if num_heads and num_key_value_heads else 1
        )
        hidden_dim = getattr(base_config, "hidden_size", None) or head_dim * num_heads
        return (
            num_heads,
            head_dim,
            num_key_value_heads,
            num_key_value_groups,
            hidden_dim,
        )

    def _apply_shared_projections(
        self, hidden_states: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, nn.Module]:
        """Project hidden states to q/k/v using base_attn's own projection
        layers, detecting the backbone's naming convention."""
        base_attn = self.base_attn
        if hasattr(base_attn, "q_proj"):
            # LLama, OLMO and Mistral style
            q = base_attn.q_proj(hidden_states)
            k = base_attn.k_proj(hidden_states)
            v = base_attn.v_proj(hidden_states)
            out_proj = base_attn.o_proj
        elif hasattr(base_attn, "qkv_proj"):
            # OpenELM and GPT-Neo style : QKV fused, split on the last dimension
            qkv = base_attn.qkv_proj(hidden_states)
            q, k, v = split_qkv(base_attn, qkv)
            out_proj = base_attn.out_proj
        elif hasattr(base_attn, "c_attn") and hasattr(base_attn, "c_proj"):
            # GPT-2 style
            qkv = base_attn.c_attn(hidden_states)
            q, k, v = qkv.chunk(3, dim=-1)
            out_proj = base_attn.c_proj
        elif all(hasattr(base_attn, n) for n in ["query", "key", "value"]):
            # BERT - ViT
            q = base_attn.query(hidden_states)
            k = base_attn.key(hidden_states)
            v = base_attn.value(hidden_states)
            out_proj = getattr(base_attn, "dense", None)  # or output.dense
        else:
            raise ValueError("Unsupported attention module: cannot find projections.")
        # Ensure stability
        q = ensure_stability(q, min_val=-1e4, max_val=1e4)
        k = ensure_stability(k, min_val=-1e4, max_val=1e4)
        v = ensure_stability(v, min_val=-1e4, max_val=1e4)
        return q, k, v, out_proj

    def _process_self_attn(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor],
        kwargs,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[DynamicCache], int]:
        """Process the self-attention part (with truncation).

        Returns (output, attn_weights, present_key_value, expected_attn_mode)
        where expected_attn_mode records the arity of base_attn's return value
        so forward() can mirror it.
        """
        if self.max_self_attn_length:  # Not needed for SWA (nonparam memorize context)
            hidden_states, attention_mask = truncate_attention_mask(
                hidden_states, attention_mask, self.max_self_attn_length
            )

            # Truncate rotary embeddings to the kept window as well.
            if kwargs.get("position_embeddings", None) is not None:
                cos, sin = kwargs["position_embeddings"]
                cos = cos[:, -self.max_self_attn_length :]
                sin = sin[:, -self.max_self_attn_length :]
                kwargs["position_embeddings"] = (cos, sin)

            if isinstance(kwargs.get("past_key_value", None), DynamicCache):
                # cache management
                if (
                    len(kwargs["past_key_value"]) > self.layer_idx
                    and self.layer_idx == 0
                ):
                    kwargs["past_key_value"].crop(self.max_self_attn_length - 1)

        # Ensure attention mask is of the correct dtype and device
        if attention_mask is not None:
            attention_mask = attention_mask.to(
                dtype=hidden_states.dtype, device=hidden_states.device
            )
        # Standard attention (mask and rotation is applied inside)
        base_attn_outputs = self.base_attn(
            hidden_states,
            attention_mask=attention_mask,
            **kwargs,
        )

        # Normalize the heterogeneous return conventions of HF attentions.
        if isinstance(base_attn_outputs, tuple):
            if len(base_attn_outputs) == 3:
                o_base, attn_weights, present_key_value = base_attn_outputs
                expected_attn_mode = 3
            elif len(base_attn_outputs) == 2:
                o_base, attn_weights = base_attn_outputs
                present_key_value, expected_attn_mode = None, 2
            else:
                raise ValueError(
                    f"Unexpected number of outputs from base_attn: {len(base_attn_outputs)}"
                )
        else:
            o_base = base_attn_outputs
            attn_weights, present_key_value, expected_attn_mode = None, None, 1
        # Ensure stability
        o_base = ensure_stability(o_base, min_val=-1e4, max_val=1e4)
        return o_base, attn_weights, present_key_value, expected_attn_mode

    def _prepare_attn_mixin(
        self,
        o_lin: torch.Tensor,
        o_base: torch.Tensor,
        tensor_dtype: torch.dtype,
        eps: float = 1e-5,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Prepare linear attn for mixing with self attn."""
        # Force cast typing, shape : [b n (h d)]
        o_lin = o_lin.to(tensor_dtype)
        o_base = o_base.to(tensor_dtype)
        # feature scaling: match the linear branch's RMS to the base branch's.
        if self.base_scale_attn:
            scaler = o_base.pow(2).mean(dim=-1, keepdim=True).add(eps).sqrt()
            o_lin = scaler * o_lin
        return o_lin, o_base

    def _apply_mag(
        self, linear_attention: torch.Tensor, softmax_attention: torch.Tensor
    ) -> torch.Tensor:
        """Apply the MAG strategy: convex mix of the two branches, optionally
        with a multiplicative cross term (``cross_gate``)."""
        # Left-Padding management
        if linear_attention.shape[1] != softmax_attention.shape[1]:
            left_trunc = min(linear_attention.shape[1], softmax_attention.shape[1])
            linear_attention, softmax_attention = (
                linear_attention[:, -left_trunc:],
                softmax_attention[:, -left_trunc:],
            )
        # NAM : Neural Attention Mixer (with graph forcing)
        mag_weight = torch.tensor(
            self.mag_weight,
            dtype=softmax_attention.dtype,
            device=softmax_attention.device,
        )
        softmax_weighted = (1 - mag_weight) * softmax_attention
        linear_weighted = mag_weight * linear_attention
        if self.cross_gate:
            output_attention = (
                softmax_weighted + linear_weighted + softmax_weighted * linear_weighted
            )  # complex cross product (unlinear interaction)
        else:
            output_attention = softmax_weighted + linear_weighted  # classic

        # Diagnostic: warn when the linear branch contributes ~nothing.
        if torch.allclose(softmax_weighted, output_attention):
            logger.info(
                "[LOG] layer : %s, softmax_weighted and output_attention are close.",
                self.layer_idx,
            )
        # Final output
        return ensure_stability(output_attention, min_val=-1e4, max_val=1e4)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> torch.Tensor:
        """Mix linear and self attention forward"""
        device = hidden_states.device
        tensor_dtype = hidden_states.dtype
        self.base_attn.to(device)

        # Cache is only meaningful for incremental decoding; drop it otherwise.
        if self.training:
            kwargs.pop("past_key_value", None)
            kwargs["use_cache"] = False
        elif "use_cache" not in kwargs:
            kwargs.pop("past_key_value", None)
            kwargs["use_cache"] = False

        kwargs.pop("position_ids", None)  # obsolete

        # Apply shared projections
        q, k, v, out_proj = self._apply_shared_projections(hidden_states)

        # Apply linear attention to hidden states
        o_lin = self.linear_attn(
            x=[q, k, v], attn_mask=attention_mask, out_proj=out_proj, **kwargs
        )

        # Process self attn with truncation
        o_base, attn_weights, present_key_value, expected_attn_mode = (
            self._process_self_attn(hidden_states, attention_mask, kwargs)
        )

        # Prepare output mixing
        o_lin, o_base = self._prepare_attn_mixin(o_lin, o_base, tensor_dtype, eps=1e-5)

        # Apply Memory as Gate in self-attention (with length management and ablation)
        out = o_base if self.disable_linear_attn else self._apply_mag(o_lin, o_base)

        # Return output following transformer convention
        if expected_attn_mode == 3:
            return out, attn_weights, present_key_value
        if expected_attn_mode == 2:
            return out, attn_weights
        return out

    @property
    def is_sliding(self):
        """Check if the base attention contain sliding window attention."""
        return getattr(self.base_attn, "is_sliding", False)
568
+
569
+
570
def load_tptt_safetensors(
    repo_or_path: str,
    model: Union[PreTrainedModel, PeftModel],
    subfolder: Optional[str] = None,
    token: Optional[str] = None,
) -> Union[PreTrainedModel, PeftModel]:
    """Load Tptt safetensor from LoRA/PEFT weights and adapt keys if needed.

    Looks for ``adapter_model.safetensors`` either in a local directory or on
    the HF Hub, rewrites its keys to match the wrapped model's naming
    (``tptt_model.base_model.`` prefix, ``.base_attn.`` insertion, LoRA
    ``.default`` suffix), casts tensors to the model's dtypes, then loads
    non-strictly.  Returns ``model`` unchanged when no adapter file exists.
    """
    # sharding not supported yet (e.g. : -00001-of-00005.safetensors, ...)
    fname = "adapter_model.safetensors"
    # subfolder management
    if subfolder:
        repo_or_path_norm = os.path.normpath(repo_or_path)
        subfolder_norm = os.path.normpath(subfolder)
        # Avoid doubling the subfolder when repo_or_path already ends with it.
        if not repo_or_path_norm.endswith(subfolder_norm):
            fname = f"{subfolder}/{fname}" if subfolder else fname
    # Find file path
    if os.path.isdir(repo_or_path):
        path = os.path.join(repo_or_path, fname)
        if not os.path.exists(path):
            return model
    else:
        if fname not in list_repo_files(repo_or_path, token=token):
            return model
        path = hf_hub_download(repo_or_path, fname, token=token)

    # Load weights from safetensors
    with safe_open(path, framework="pt") as f:
        state_dict = {k: f.get_tensor(k) for k in f.keys()}

    # Adapt LoRA/Specific keys if needed (add .default if expected by the model)
    def adapt_keys(sd, model):
        model_keys = list(model.state_dict().keys())
        # Detect the wrapping depth of the live model to choose the prefix.
        if any(k.startswith("tptt_model.base_model.") for k in model_keys):
            prefix = "tptt_model.base_model."
        elif any(k.startswith("base_model.") for k in model_keys):
            prefix = "base_model."
        else:
            prefix = ""

        has_base_attn = any(".base_attn." in k for k in model_keys)

        def adapt_key(k):
            k_ = k if k.startswith(prefix) else prefix + k
            # first, verify and modify base_attn (LiZA)
            if ".base_attn." in k_ and not has_base_attn:
                k_ = k_.replace(".base_attn.", ".")
            # change LoRA if needed
            if (
                k_.endswith("lora_A.weight") or k_.endswith("lora_B.weight")
            ) and k_.replace(".weight", ".default.weight") in model_keys:
                k_ = k_.replace(".weight", ".default.weight")
            return k_

        return {adapt_key(k): v for k, v in sd.items()}

    state_dict = adapt_keys(state_dict, model)

    # Cast tensors to the expected dtype of the model parameters
    model_state_dict = model.state_dict()
    for k, v in state_dict.items():
        if k in model_state_dict:
            expected_dtype = model_state_dict[k].dtype
            if v.dtype != expected_dtype:
                state_dict[k] = v.to(expected_dtype)

    logger.info("Input LoRA/Specific keys: %s", [k for k in state_dict.keys()])

    # Load into model; strict=False because the adapter covers a subset of keys.
    missing, unexpected = model.load_state_dict(state_dict, strict=False, assign=True)
    missing_lora = [k for k in missing if "lora" in k]
    if missing_lora:
        logger.warning("Missing keys: %s", missing_lora)
    if unexpected:
        logger.warning("Unexpected keys: %s", unexpected)
    return model
645
+
646
+
647
def get_tptt_model(  # pylint: disable=too-many-arguments, too-many-positional-arguments
    model: nn.Module,
    base_config: PretrainedConfig,  # or LlamaConfig, MistralConfig, etc.
    linear_cache: Optional[LCache] = None,
    liza_attention: nn.Module = LiZAttention,
    target_modules_names: Optional[list[str]] = None,
    operator_mode: str = "delta_rule",
    use_linear_checkpoint: bool = False,
    recurrent_config: Optional[Dict[str, Any]] = None,
    base_scale_attn: bool = False,
    mag_weight: float = 0.5,
    cross_gate: bool = False,
    max_chunk_size: int = 64,
    linear_precision: torch.dtype = torch.float32,
    max_self_attn_length: Optional[int] = None,  # unnecessary
    padding_side: str = "right",  # for tokenizer
    bidirectional: bool = False,  # if True, use bidirectional attention
    pooling_config: Optional[Dict[str, Any]] = None,
    **kwargs,  # quickfix unexpected arguments
) -> Tuple[PreTrainedModel, LCache]:
    """Replace target attention modules in a model with LiZAttention.

    Args:
        model: Backbone whose attention modules are wrapped in place.
        base_config: Backbone configuration used to size the linear branch.
        linear_cache: Shared linear-state cache; a new LCache is created if None.
        liza_attention: Wrapper class constructed around each matched module.
        target_modules_names: Name suffixes identifying attention modules;
            defaults to ["attn", "self_attn", "attention"].
        Remaining keyword arguments are forwarded to ``liza_attention``.

    Returns:
        Tuple of the modified model and the (possibly new) linear cache.

    Raises:
        ValueError: If no module name matches the requested suffixes.
    """
    if target_modules_names is None:
        target_modules_names = ["attn", "self_attn", "attention"]
    # Keep the searched suffixes separate from the matched names: the original
    # code overwrote the suffix list, so its error message always printed [].
    suffixes = target_modules_names
    # Match modules whose name ends with a suffix, excluding children of an
    # already-matched module (e.g. "...self_attn.q_proj").
    matched_names = [
        name
        for name, _ in model.named_modules()
        if any(name.endswith(suffix) for suffix in suffixes)
        and not any(f".{suffix}." in name for suffix in suffixes)
    ]
    if not matched_names:
        raise ValueError(f"Target modules '{suffixes}' not found in the model.")
    # Prepare recurrent config
    linear_cache = linear_cache or LCache()
    # Inject LiZAttention into the model. Iterate over the pre-computed list
    # rather than named_modules() so we never mutate the module tree while
    # iterating over it.
    for name in matched_names:
        parent = model
        *path, last = name.split(".")
        for part in path:
            parent = getattr(parent, part)
        layer_idx = extract_layer_idx(name)
        setattr(
            parent,
            last,
            liza_attention(
                getattr(parent, last),
                layer_idx=layer_idx,
                base_config=base_config,
                linear_cache=linear_cache,
                operator_mode=operator_mode,
                use_linear_checkpoint=use_linear_checkpoint,
                recurrent_config=recurrent_config,
                max_self_attn_length=max_self_attn_length,
                base_scale_attn=base_scale_attn,
                mag_weight=mag_weight,
                cross_gate=cross_gate,
                max_chunk_size=max_chunk_size,
                linear_precision=linear_precision,
                padding_side=padding_side,
                bidirectional=bidirectional,
                pooling_config=pooling_config,
            ),
        )
    return model, linear_cache
714
+
715
+
716
def save_tptt_safetensors(model, path: str, name: str = "adapter_model.safetensors"):
    """Save trainable LoRA/Specific weights with normalized key names.

    Args:
        model: Module whose trainable (requires_grad) parameters are exported.
        path: Target directory; created if missing. The file is written at
            ``path/name``.
        name: Output file name (sharding is not supported yet).

    Only parameters with ``requires_grad=True`` are saved (usually LoRA/PEFT
    adapters); wrapper prefixes are stripped so keys are portable.
    """
    # 1. Get the full state_dict
    all_sd = model.state_dict()

    # 2. Identify trainable parameter names (usually only LoRA/PEFT adapters)
    trainable_keys = [
        param_name
        for param_name, param in model.named_parameters()
        if param.requires_grad
    ]  # Also, you can manually select specific keys in model after load

    # 3. Filter and adapt the keys (Remove custom model encapsulation info)
    to_save = {
        k.replace("tptt_model.", "").replace("base_model.", ""): all_sd[k]
        for k in trainable_keys
    }

    # 4. Save the filtered adapters to a safetensors file
    if to_save:
        # BUG FIX: create the target directory itself — the file is written
        # inside `path`; os.makedirs(os.path.dirname(path)) only created its
        # parent and failed when `path` did not already exist.
        os.makedirs(path, exist_ok=True)
        # sharding not supported yet (e.g. : -00001-of-00005.safetensors, ...)
        save_file(to_save, os.path.join(path, name))
737
+
738
+
739
+ class TpttModel(PreTrainedModel):
740
+ """
741
+ TPTT model wrapper with linear attention (LiZA) and LoRA support.
742
+ Handles only architecture and weights.
743
+ """
744
+
745
+ config_class = TpttConfig
746
+
747
    def __init__(
        self,
        config: TpttConfig,
        **kwargs,
    ):
        """
        Initialize TpttModel with a given config and backbone.
        Injects LiZA attention modules into the backbone.

        Steps: load the backbone, wrap its attention with LiZA, optionally
        apply LoRA, then load any saved TPTT/PEFT adapter weights.
        """
        super().__init__(config, **kwargs)
        # Where a previously saved adapter may live (local dir or hub repo id).
        repo_or_path = getattr(config, "_base_path", None) or config._name_or_path

        # 1. Load backbone (with subfolder management) :
        kwargs_bb = kwargs.copy()
        if config.base_model_subfolder is not None:
            kwargs_bb["subfolder"] = config.base_model_subfolder
        else:
            kwargs_bb.pop("subfolder", None)

        if config.model_task == "causal_lm":
            tptt_model = AutoModelForCausalLM.from_pretrained(
                config.base_model_name, **kwargs_bb
            )
        else:
            tptt_model = AutoModel.from_pretrained(config.base_model_name, **kwargs_bb)

        # 2. Inject LiZA attention (config fields are forwarded as kwargs)
        self.linear_cache = LCache()
        tptt_model, self.linear_cache = get_tptt_model(
            tptt_model, config, self.linear_cache, **config.to_dict()
        )

        # 3. Apply LoRA/Specific if present and configured
        if config.lora_config is not None:
            lora_config_obj = LoraConfig(**config.lora_config)
            tptt_model = get_peft_model(tptt_model, lora_config_obj)
        else:
            # Doesn't work if quantization is applied !
            tptt_model = set_trainable_parameters(tptt_model)

        # 4. Load safetensor if tptt/peft adaptor in repo
        if repo_or_path:
            tptt_model = load_tptt_safetensors(
                repo_or_path,
                tptt_model,
                subfolder=kwargs.get("subfolder", None),
                token=kwargs.get("token", None),
            )
        self.tptt_model = tptt_model
796
+
797
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        labels: Optional[torch.LongTensor] = None,
        **kwargs,
    ):
        """Forward pass. All arguments are passed to the underlying base model."""
        if self.training:
            # Training never uses the KV cache; drop trainer-only kwargs.
            kwargs["use_cache"] = False
            kwargs.pop("num_items_in_batch", None)
        elif "use_cache" not in kwargs:  # evaluation
            # Caller did not opt into caching: disable it explicitly too.
            kwargs.pop("num_items_in_batch", None)
            kwargs["use_cache"] = False
        return self.tptt_model(
            input_ids=input_ids, attention_mask=attention_mask, labels=labels, **kwargs
        )
814
+
815
+ def generate(self, *args, **kwargs):
816
+ """Delegate the generate call to the backbone model, which supports generation"""
817
+ return self.tptt_model.generate(*args, **kwargs)
818
+
819
    def save_pretrained(self, path: str, **kwargs):
        """Save model weights, config, and source code to the given path."""
        # 0. Save complete tptt config (with or without LoRA)
        super().save_pretrained(path, **kwargs)  # pylint: disable=no-member
        # Remove/replace artifacts the generic save wrote (full weights, PEFT config).
        self._adjust_save_strategy(path, **kwargs)
        # 1. Save true weights and adapt keys
        save_tptt_safetensors(self, path)
        # 2. Copy Python files for trust_remote_code
        self._copy_source_files(path, **kwargs)
828
+
829
    def _adjust_save_strategy(self, path: str, **kwargs):
        """Re-adapt/remove the weight safetensor and saved adapter config.

        When the backbone is a PeftModel, let PEFT write its own adapter files;
        then delete the full-model safetensor and adapter config left behind by
        the generic save (the TPTT adapter file replaces them).
        """
        if isinstance(self.tptt_model, PeftModel):
            self.tptt_model.save_pretrained(path, **kwargs)
        safetensor_path = os.path.join(path, "model.safetensors")
        if os.path.exists(safetensor_path):
            os.remove(safetensor_path)
        adapter_path = os.path.join(path, "adapter_config.json")
        if os.path.exists(adapter_path):
            os.remove(adapter_path)
839
+
840
+ def _copy_source_files(self, target_path: str, **kwargs):
841
+ """Copy all .py files from package directory for trust_remote_code."""
842
+ src_dir = os.path.dirname(os.path.abspath(__file__))
843
+ dst_dir = (
844
+ f"./{str(Path(target_path).parts[0])}"
845
+ if kwargs.get("subfolder", False)
846
+ else target_path
847
+ )
848
+ for fname in os.listdir(src_dir):
849
+ if fname.endswith(".py"):
850
+ src = os.path.join(src_dir, fname)
851
+ dst = os.path.join(dst_dir, fname)
852
+ shutil.copy2(src, dst)
853
+
854
+ def retie_lm_after_load(self, **kwargs):
855
+ """Re-link lm_head after loading external weights."""
856
+ embed_lm = find_embedding_lm(self.tptt_model)
857
+ if embed_lm is not None and hasattr(self.tptt_model, "lm_head"):
858
+ if self.tptt_model.lm_head is None: # ensure lm_head exists
859
+ self.tptt_model.lm_head = nn.Linear(
860
+ embed_lm.weight.shape[1], embed_lm.weight.shape[0], bias=False
861
+ )
862
+ if kwargs.get("tie_word_embeddings", True):
863
+ self.tptt_model.lm_head.weight = embed_lm.weight # share weights
864
+ logger.info("Weights of lm_head have been shared with embedding.")
865
+ else:
866
+ self.tptt_model.lm_head.weight = nn.Parameter(embed_lm.weight.clone())
867
+ logger.info("Weights of lm_head have been cloned from the embedding.")
868
+
869
    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path=None, *model_args, **kwargs):
        """Custom from_pretrained that accepts the standard positional argument.

        Checkpoint location is resolved in order: the positional argument,
        the ``pretrained_model_name_or_path`` / ``repo_or_path`` kwargs, then
        the config's ``_base_path`` / ``_name_or_path`` attributes.
        """
        config = kwargs.pop("config", None)
        repo_or_path = (
            pretrained_model_name_or_path
            or kwargs.pop("pretrained_model_name_or_path", None)
            or kwargs.pop("repo_or_path", None)
            or (getattr(config, "_base_path", None) if config else None)
            or (getattr(config, "_name_or_path", None) if config else None)
        )

        if config is None and repo_or_path is not None:
            config = AutoConfig.from_pretrained(repo_or_path, **kwargs)
        # NOTE(review): the constructor presumably loads weights/adapters
        # itself (its body is defined elsewhere), so no state_dict is passed.
        model = cls(config, *model_args, **kwargs)
        # Re-tie lm_head to the embedding after external weight loading;
        # honors kwargs.get("tie_word_embeddings", True).
        model.retie_lm_after_load(**kwargs)
        return model
886
+
887
+
888
+ TpttModel.register_for_auto_class("AutoModelForCausalLM")
889
+
890
+
891
class LinearAttentionOp(nn.Module):
    """Base class for linear attention operators.

    Wraps the chunkwise delta-rule / delta-product recurrence with an
    optional per-layer cache (``LCache``) so it can also run incrementally
    at generation time.
    """

    def __init__(
        self,
        layer_idx: int,
        operator_mode: str = "delta_rule",
        use_linear_checkpoint: bool = False,
        recurrent_config: Optional[dict] = None,
        max_chunk_size: int = 64,
        linear_cache: Optional[LCache] = None,
        linear_precision: torch.dtype = torch.float32,
    ):
        """Configure the operator.

        Args:
            layer_idx: Index of the attention layer; key into the cache.
            operator_mode: Recurrence variant; forced back to "delta_rule"
                when no ``recurrent_config`` is given.
            use_linear_checkpoint: Use gradient checkpointing per chunk
                (only applied when the cache is not in use; see forward()).
            recurrent_config: Dict with "order", "gate_type", "linear" and
                "trick" keys; defaults to a first-order delta rule.
            max_chunk_size: Upper bound on the chunk length.
            linear_cache: Shared LCache instance; a fresh one is created
                when omitted.
            linear_precision: Dtype used throughout the recurrence
                (float32 by default, for numerical stability).
        """
        super().__init__()
        self.layer_idx = layer_idx
        if recurrent_config is None:
            operator_mode = "delta_rule"  # force default operator mode if no config
            recurrent_config = {
                "order": 1,
                "gate_type": "k",
                "linear": True,
                "trick": "derivative",
            }
        self.operator_mode = operator_mode
        self.use_linear_checkpoint = use_linear_checkpoint

        # Recurrence hyper-parameters consumed by chunk_delta_product_forward.
        self.order = recurrent_config["order"]
        self.gate_type = recurrent_config["gate_type"]
        self.linear = recurrent_config["linear"]
        self.trick = recurrent_config["trick"]

        self.max_chunk_size = max_chunk_size
        self.linear_cache = linear_cache or LCache()
        self.linear_precision = linear_precision

    def compute_gate(self, beta: Tuple[torch.Tensor]) -> torch.Tensor:
        """
        Compute the gating tensor according to the gate_type.

        ``beta`` is a (key-gate, value-gate) pair; the result is clamped to
        the open interval (0, 1) for stability.
        """
        if self.gate_type == "k":
            return torch.clamp(beta[0], min=1e-6, max=1 - 1e-6)
        if self.gate_type == "v":
            return torch.clamp(beta[1], min=1e-6, max=1 - 1e-6)
        if self.gate_type == "kv":
            return torch.clamp(beta[0] * beta[1], min=1e-6, max=1 - 1e-6)
        raise ValueError(f"Unsupported gate_type: {self.gate_type}")

    def get_cache(self, use_cache: bool) -> Tuple[
        Optional[torch.Tensor],
        Optional[Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]],
    ]:
        """
        Retrieve recurrent state and qkv buffers from the cache.

        Returns ``(None, None)`` when caching is disabled or this layer has
        no stored state yet.
        """
        if not use_cache:
            return None, None
        last_state = self.linear_cache[self.layer_idx]
        if last_state is not None:
            recurrent_state = last_state.get("recurrent_state", None)
            qkv_buffers = last_state.get("qkv", None)
        else:
            recurrent_state = None
            qkv_buffers = None
        return recurrent_state, qkv_buffers

    def save_cache(
        self,
        use_cache: bool,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        gate: torch.Tensor,
        state: torch.Tensor,
    ) -> None:
        """
        Save the recurrent state and qkv buffers to the cache.

        For higher-order recurrences (order > 1) the trailing ``order - 1``
        tokens of q/k/v/gate are kept so the next incremental call can
        rebuild its virtual-token context.
        """
        if not use_cache:
            return
        if self.order > 1:
            qkv_buffers = (
                q[:, :, -(self.order - 1) :, :],
                k[:, :, -(self.order - 1) :, :],
                v[:, :, -(self.order - 1) :, :],
                gate[:, :, -(self.order - 1) :, :],
            )
        else:
            qkv_buffers = None
        self.linear_cache.update(self.layer_idx, recurrent_state=state, qkv=qkv_buffers)

    def forward(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        beta: Union[Tuple[torch.Tensor], torch.Tensor],
        **kwargs,
    ) -> torch.Tensor:
        """
        Forward pass for the attention operator.

        Assumes q/k/v are shaped [batch, heads, seq, head_dim] (this is
        how chunk_delta_product_forward unpacks them).
        """
        # Ensure linear_precision for numerical stability (float32)
        q, k, v = [x.to(self.linear_precision) for x in (q, k, v)]
        if isinstance(beta, (tuple, list)):
            beta = tuple(b.to(self.linear_precision) for b in beta)
        else:
            beta = beta.to(self.linear_precision)

        gate = self.compute_gate(beta)

        # Retrieve cache if needed; checkpointing only makes sense without cache.
        use_cache = kwargs.get("use_cache", False)
        use_checkpoint = not (use_cache) and self.use_linear_checkpoint
        recurrent_state, qkvb = self.get_cache(use_cache)

        # NOTE(review): cached buffers are only prepended when their shape
        # equals the incoming q's shape exactly — confirm this condition is
        # intended (buffers hold the trailing order-1 tokens).
        if qkvb is not None and qkvb[0].shape == q.shape:
            q = torch.cat([qkvb[0].to(q.device), q], dim=2).to(self.linear_precision)
            k = torch.cat([qkvb[1].to(q.device), k], dim=2).to(self.linear_precision)
            v = torch.cat([qkvb[2].to(q.device), v], dim=2).to(self.linear_precision)
            gate = torch.cat([qkvb[3].to(q.device), gate], dim=2).to(
                self.linear_precision
            )

        output, state = self.chunk_delta_product_forward(
            q,
            k,
            v,
            gate,
            self.max_chunk_size,
            n=self.order,
            trick=self.trick,
            linear=self.linear,
            initial_state=recurrent_state,
            use_checkpoint=use_checkpoint,
            linear_precision=self.linear_precision,
        )

        # Save cache if needed
        self.save_cache(use_cache, q, k, v, gate, state)

        return output

    @staticmethod
    def chunk_delta_product_forward(
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        beta_gate: torch.Tensor,
        chunk_size: int,
        n: int = 1,
        trick: str = "derivative",
        linear: bool = True,
        initial_state: Optional[torch.Tensor] = None,
        use_checkpoint: bool = True,
        linear_precision: torch.dtype = torch.float32,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Chunkwise parallel implementation https://arxiv.org/abs/2406.06484
        For each chunk, processes chunk_size * n_orders steps (virtual tokens) in order.

        Returns:
            (output, final_state) where output is [B, H, S, D] (real tokens
            only — one per group of n virtual tokens) and final_state is the
            recurrent state [B, H, n, D, D].
        """

        # --- Main chunk_delta_product_forward logic ---

        batch_size, num_heads, seq_len, head_dim = query.shape
        # Shrink chunk_size so it divides seq_len exactly.
        chunk_size = get_valid_chunk_size(seq_len, chunk_size)
        num_chunks = seq_len // chunk_size

        # Expand each real token into n virtual tokens (no-op for n == 1).
        query_n = query if n == 1 else expand_virtual_tokens(query, n, trick)
        key_n = key if n == 1 else expand_virtual_tokens(key, n, trick)
        value_n = value if n == 1 else expand_virtual_tokens(value, n, trick)
        beta_n = beta_gate if n == 1 else expand_virtual_tokens(beta_gate, n, trick)

        q_chunks = chunk_sequence(query_n, num_chunks, chunk_size * n)
        k_chunks = chunk_sequence(key_n, num_chunks, chunk_size * n)
        v_chunks = chunk_sequence(value_n, num_chunks, chunk_size * n)
        beta_chunks = chunk_sequence(beta_n, num_chunks, chunk_size * n)

        # Gated keys/values used by the (I - tril(K_b K^T))^-1 solve below.
        k_beta = k_chunks * beta_chunks
        v_beta = v_chunks * beta_chunks

        # Strictly lower-triangular intra-chunk interaction matrix.
        householder = -(k_beta @ k_chunks.transpose(-2, -1)).tril(-1)
        householder = ensure_stability(householder, min_val=-1e4, max_val=1e4)

        # size : N = chunk_size * n
        inv_hh = fast_invert_matrix(householder, dtype=linear_precision)  # [(...),N,N]

        # w/u: per-chunk solved update directions (clamped for stability).
        w = ensure_stability(torch.matmul(inv_hh, k_beta), min_val=-1e4, max_val=1e4)
        u = ensure_stability(torch.matmul(inv_hh, v_beta), min_val=-1e4, max_val=1e4)

        state_shape = (batch_size, num_heads, n, head_dim, head_dim)
        if initial_state is not None and initial_state.shape == state_shape:
            state = initial_state.to(device=query.device, dtype=linear_precision)
        else:
            # Fresh (near-zero) state; exact zeros avoided for stability
            # when the non-linear inter-chunk activation is enabled.
            state = torch.full(
                state_shape,
                fill_value=1e-6,  # stability if unlinear activation
                device=query.device,
                dtype=linear_precision,
            )

        output, final_state = sequential_delta_product_scan(
            q_chunks.to(dtype=linear_precision),
            w.to(dtype=linear_precision),
            u.to(dtype=linear_precision),
            n,
            linear,
            chunk_size,
            state.to(dtype=linear_precision),
            linear_precision=linear_precision,
            use_checkpoint=use_checkpoint,
        )

        # Keep only the last virtual token of each group of n (the real token).
        idx_last_order = torch.arange(chunk_size, device=output.device) * n + (n - 1)
        output = output[:, :, :, idx_last_order, :]  # [B, H, num_chunks, chunk_size, D]
        output = output.reshape(batch_size, num_heads, seq_len, head_dim)

        return output.to(dtype=linear_precision), final_state.to(dtype=linear_precision)
1108
+
1109
+
1110
def sequential_delta_product_scan(
    q_chunks: torch.Tensor,
    w: torch.Tensor,
    u: torch.Tensor,
    n_orders: int,
    linear_activation: bool,
    current_chunk_size: int,
    initial_recurrent_state: torch.Tensor,
    linear_precision: torch.dtype,
    use_checkpoint: bool,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    DeltaProduct implementation https://arxiv.org/abs/2502.10297
    Implements the per-token Householder state updates.

    Args:
        q_chunks, w, u: Chunked inputs of shape [B, H, num_chunks, chunk*n, D].
        n_orders: Number of virtual tokens (Householder steps) per real token.
        linear_activation: When False, a non-linear activation is applied
            to the carried state between chunks.
        current_chunk_size: Number of real tokens per chunk.
        initial_recurrent_state: [B, H, n, D, D] starting state.
        linear_precision: Dtype of the recurrence.
        use_checkpoint: Gradient-checkpoint each chunk (training memory saver).

    Returns:
        (output, last_state) with output shaped like q_chunks and last_state
        the [B, H, D, D] state after the final chunk.
    """
    batch, head, num_chunks_inner, chunk_n_total, dim = q_chunks.shape
    output_inner = torch.empty_like(q_chunks)
    # initial_recurrent_state is H_{last_token_of_prev_chunk, n-1} ([B, H, D, D])
    h_0_base = initial_recurrent_state[:, :, -1, :, :].clone()

    def process_one_chunk(
        q_chunk_params: torch.Tensor,
        w_chunk_params: torch.Tensor,
        u_chunk_params: torch.Tensor,
        h_0_base: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Process a single chunk (with per-token state for n_orders > 1).
        """
        o_intra_current_chunk = torch.zeros(
            batch,
            head,
            chunk_n_total,
            dim,
            device=q_chunk_params.device,
            dtype=linear_precision,
        )
        o_inter_current_chunk = torch.zeros_like(o_intra_current_chunk)
        # Every real token in the chunk starts from the same carried state.
        current_accumulated_state_per_token = (
            h_0_base.unsqueeze(2).expand(-1, -1, current_chunk_size, -1, -1).clone()
        )  # [B, H, current_chunk_size, D, D]

        # One Householder update per virtual-token order.
        for step in range(n_orders):
            # Select the step-th virtual token of each real token.
            idx_virtual_tokens = (
                torch.arange(current_chunk_size, device=q_chunk_params.device)
                * n_orders
                + step
            )
            q_s = q_chunk_params[:, :, idx_virtual_tokens, :]
            w_s = w_chunk_params[:, :, idx_virtual_tokens, :]
            u_s = u_chunk_params[:, :, idx_virtual_tokens, :]

            state_input_for_this_step = current_accumulated_state_per_token

            ## BLAS/cuBLAS einsum "bhcd,bhcdd->bhcd"
            k_trans_h_old = (
                torch.matmul(
                    w_s.unsqueeze(-2),
                    state_input_for_this_step,
                )
                .squeeze(-2)
                .to(dtype=linear_precision)
            )

            # Delta-rule residual: target minus current prediction.
            u_val = u_s - k_trans_h_old

            # Inter-chunk contribution: query against the carried state.
            o_inter_current_chunk[:, :, idx_virtual_tokens, :] = (
                torch.matmul(q_s.unsqueeze(-2), state_input_for_this_step)
                .squeeze(-2)
                .to(dtype=linear_precision)
            )

            ## BLAS/cuBLAS einsum "bhcd,bhcd->bhcd"
            # Intra-chunk contribution from this step's rank-1 update.
            o_intra_current_chunk[:, :, idx_virtual_tokens, :] = (q_s * u_val).to(
                dtype=linear_precision
            )

            # Rank-1 (Householder-style) state update: H += w u^T.
            outer_product_term = torch.matmul(w_s.unsqueeze(-1), u_val.unsqueeze(-2))
            new_state_i_per_token = state_input_for_this_step + outer_product_term
            current_accumulated_state_per_token = new_state_i_per_token.to(
                dtype=linear_precision
            )
        # Return all needed for next chunk
        return (
            o_intra_current_chunk,
            o_inter_current_chunk,
            current_accumulated_state_per_token[:, :, -1, :, :],  # new h_0_base
        )

    # Chunks are processed strictly in order: each one consumes the state
    # produced by the previous chunk.
    for chunk_idx_inner in range(num_chunks_inner):
        q_chunk_params = q_chunks[:, :, chunk_idx_inner]
        w_chunk_params = w[:, :, chunk_idx_inner]
        u_chunk_params = u[:, :, chunk_idx_inner]

        # Checkpointed call if training
        call = (
            partial(checkpoint, use_reentrant=False)
            if use_checkpoint
            else lambda f, *a: f(*a)
        )
        o_intra, o_inter, h_0_base = call(
            process_one_chunk,
            q_chunk_params,
            w_chunk_params,
            u_chunk_params,
            h_0_base,
        )
        if not linear_activation:  # unlinear activation between chunks
            h_0_base = unlinear_activation(h_0_base).to(dtype=linear_precision)
        output_inner[:, :, chunk_idx_inner] = o_intra + o_inter

    return output_inner, h_0_base
1222
+
1223
+
1224
def unlinear_activation(x: torch.Tensor, scale: float = 2.0) -> torch.Tensor:
    """Norm-scaled GELU gating applied to the carried state between chunks."""
    norm = x.norm(p=2, dim=-1, keepdim=True) + 1e-6
    gate = F.gelu(scale * x / norm, approximate="tanh")  # pylint: disable=not-callable
    return (x / scale) * gate
1229
+
1230
+
1231
def chunk_sequence(x: torch.Tensor, num_chunks: int, chunk_size: int) -> torch.Tensor:
    """Reshape [B, H, S, D] into [B, H, num_chunks, chunk_size, D] (S must equal num_chunks * chunk_size)."""
    b, h, _, d = x.shape
    return x.reshape(b, h, num_chunks, chunk_size, d)
1235
+
1236
+
1237
def expand_virtual_tokens(
    x: torch.Tensor, n: int, mode: str = "derivative"
) -> torch.Tensor:
    """Expand tokens into 'n' virtual tokens using the selected trick.

    Args:
        x: Input of shape [B, H, S, D].
        n: Number of virtual tokens per real token.
        mode: "derivative" (finite-difference stencil), "rotative"
            (pairwise 2-D rotations), or "combined" (their average).

    Returns:
        Tensor of shape [B, H, S * n, D].

    Raises:
        ValueError: If ``mode`` is not one of the supported tricks.
    """
    batch_size, num_heads, seq_len, head_dim = x.shape
    device, dtype = x.device, x.dtype

    def derivative_expand(x: torch.Tensor) -> torch.Tensor:
        """Expand tokens using the derivative trick."""
        # Left-pad with n-1 zero tokens so every position has a full window.
        x_pad = torch.cat(
            [
                torch.zeros(
                    batch_size, num_heads, n - 1, head_dim, device=device, dtype=dtype
                ),
                x,
            ],
            dim=2,
        )
        # Alternating-sign binomial stencil (order n-1 finite difference),
        # L1-normalized so the coefficients sum to 1 in absolute value.
        coeffs = torch.tensor(
            [(-1) ** k * math.comb(n - 1, k) for k in range(n)],
            device=device,
            dtype=dtype,
        )
        coeffs /= coeffs.norm(p=1)
        # unfold builds the sliding windows; flip/permute/reshape interleave
        # the n stencil outputs as consecutive virtual tokens.
        return (
            (x_pad.unfold(2, n, 1) * coeffs.view(1, 1, 1, 1, n))
            .flip(-1)
            .permute(0, 1, 2, 4, 3)
            .reshape(batch_size, num_heads, seq_len * n, head_dim)
        )

    def rotative_expand(x: torch.Tensor) -> torch.Tensor:
        """Expand tokens using the rotative trick."""
        d_parity = head_dim // 2
        # n evenly spaced rotation angles over the full circle.
        angles = torch.arange(n, device=device, dtype=dtype) * (2 * math.pi / n)
        cos = torch.cos(angles).view(1, 1, 1, n, 1)
        sin = torch.sin(angles).view(1, 1, 1, n, 1)
        # Odd head_dim: rotate pairs from the first head_dim-1 dims and
        # carry the last dim through unrotated.
        if head_dim % 2:
            x_pairs = x[..., :-1].view(batch_size, num_heads, seq_len, d_parity, 2)
        else:
            x_pairs = x.view(batch_size, num_heads, seq_len, d_parity, 2)
        x_pairs = x_pairs.unsqueeze(3).expand(
            batch_size, num_heads, seq_len, n, d_parity, 2
        )
        # Standard 2-D rotation applied to each (even, odd) coordinate pair.
        x0, x1 = x_pairs[..., 0], x_pairs[..., 1]
        x0r = x0 * cos - x1 * sin
        x1r = x0 * sin + x1 * cos
        rot = torch.stack([x0r, x1r], -1).reshape(
            batch_size, num_heads, seq_len, n, d_parity * 2
        )
        if head_dim % 2:
            last = (
                x[..., -1]
                .unsqueeze(-1)
                .unsqueeze(3)
                .expand(batch_size, num_heads, seq_len, n, 1)
            )
            rot = torch.cat([rot, last], -1)
        return rot.reshape(batch_size, num_heads, seq_len * n, head_dim)

    if mode == "derivative":
        return derivative_expand(x)
    if mode == "rotative":
        return rotative_expand(x)
    if mode == "combined":
        return (derivative_expand(x) + rotative_expand(x)) / 2
    raise ValueError(f"Unknown mode: {mode}")
1304
+
1305
+
1306
def extract_layer_idx(module_name: str) -> int:
    """Return the dot-delimited layer index in a module name, or -1 if absent."""
    found = re.search(r"\.(\d+)\.", module_name)
    return int(found.group(1)) if found else -1
1312
+
1313
+
1314
def find_embedding_lm(module: nn.Module) -> Optional[nn.Module]:
    """Return the first submodule exposing token embeddings with a weight, or None."""
    for _, child in module.named_modules():
        for attr in ("embed_tokens", "token_embeddings"):
            embedding = getattr(child, attr, None)
            if embedding is not None and hasattr(embedding, "weight"):
                return embedding
    return None
1324
+
1325
+
1326
def set_trainable_parameters(
    model: PreTrainedModel, trainable_patterns: Optional[List[str]] = None
) -> PreTrainedModel:
    """Freeze all parameters except those whose name matches a pattern.

    Args:
        model: Model whose parameters are updated in place.
        trainable_patterns: Substrings identifying trainable parameters.
            Defaults to the usual attention-projection names. (Annotation
            fixed: the default is None, so the type is Optional.)

    Returns:
        The same model instance, with ``requires_grad`` flags updated.
    """
    if trainable_patterns is None:
        # Default: attention projections across common architectures.
        trainable_patterns = [
            "q_proj",
            "k_proj",
            "v_proj",
            "o_proj",
            "qkv_proj",
            "out_proj",
            "c_attn",
            "c_proj",
            "query",
            "key",
            "value",
        ]

    for name, param in model.named_parameters():
        param.requires_grad = any(pattern in name for pattern in trainable_patterns)

    trainable_layers = [n for n, p in model.named_parameters() if p.requires_grad]
    logger.info("Trainable parameters after freeze: %s", trainable_layers)
    return model
1351
+
1352
+
1353
def ensure_stability(
    tensor: torch.Tensor, min_val: float = -1e4, max_val: float = 1e4
) -> torch.Tensor:
    """Clamp values into [min_val, max_val] and scrub NaN/inf, preserving dtype."""
    original_dtype = tensor.dtype
    midpoint = (max_val + min_val) / 2
    bounded = torch.clamp(tensor, min=min_val, max=max_val)
    cleaned = torch.nan_to_num(bounded, nan=midpoint, posinf=max_val, neginf=min_val)
    return cleaned.to(dtype=original_dtype)
1362
+
1363
+
1364
def apply_linear_attention_mask(
    attention_mask: torch.Tensor, v: torch.Tensor, padding_side: str = "right"
) -> torch.Tensor:
    """Zero out padded positions of ``v`` using an attention mask.

    Accepts 2-D [B, S] masks, broadcastable masks with singleton dims, or
    4-D additive masks [B, 1, S, S]; non-binary (additive) masks are
    converted to {0, 1}. Returns ``v`` with padded positions zeroed.
    """
    # Extract a [B, S] mask: for 4-D additive masks, the diagonal gives the
    # per-position (un)masked value.
    if attention_mask.dim() == 4 and attention_mask.shape[1] == 1:
        mask = attention_mask.diagonal(dim1=-2, dim2=-1).squeeze(1)
    else:
        # Drop all singleton dims after batch.
        # NOTE(review): tuple-valued ``dim`` for squeeze requires torch >= 2.0.
        mask = attention_mask.squeeze(
            dim=tuple(
                i
                for i in range(1, attention_mask.dim())
                if attention_mask.shape[i] == 1
            )
        )
    # Ensure cast to the same dtype as v and convert to binary mask
    if not (
        mask.dtype == torch.bool
        or (
            mask.dtype in [torch.uint8, torch.int32, torch.int64]
            and mask.max() <= 1
            and mask.min() >= 0
        )
    ):
        # Additive mask: finite (>= 0) entries are kept, -inf entries dropped.
        mask = (mask >= 0).to(v.dtype)  # [-inf, 0, 0, -inf] --> [0, 1, 1, 0]
    else:
        mask = mask.to(v.dtype)
    # mask is [batch, seq] --> Broadcast to v [batch, seq, (...)]
    # Align the mask window with v's sequence dim according to padding side.
    if padding_side == "left":
        mask = mask[:, -v.shape[-2] :][(...,) + (None,) * (v.dim() - 2)]
    else:  # right padding
        mask = mask[:, : v.shape[-2]][(...,) + (None,) * (v.dim() - 2)]
    return v * mask
1396
+
1397
+
1398
def truncate_attention_mask(
    hidden_states: torch.Tensor, attention_mask: torch.Tensor, max_length: int
) -> tuple[torch.Tensor, torch.Tensor]:
    """Keep only the trailing ``max_length`` window of states and mask."""
    seq_dim = 1  # convention: (batch, seq, ...)
    seq_len = hidden_states.shape[seq_dim]
    if seq_len <= max_length:
        return hidden_states, attention_mask
    hidden_states = hidden_states.narrow(seq_dim, seq_len - max_length, max_length)
    if attention_mask is None:
        return hidden_states, attention_mask
    ndim = attention_mask.dim()
    if ndim == 2:  # mask [batch, seq]
        attention_mask = attention_mask[:, -max_length:]
    elif ndim == 3:  # mask [batch, seq, seq]
        attention_mask = attention_mask[:, -max_length:, -max_length:]
    elif ndim == 4 and attention_mask.shape[1] == 1:  # mask [batch, 1, seq, seq]
        attention_mask = attention_mask[:, :, -max_length:, -max_length:]
    else:
        raise ValueError(
            "No dimension in attention_mask matches sequence length of hidden_states."
        )
    return hidden_states, attention_mask
1421
+
1422
+
1423
def fast_invert_matrix(
    tri_tensor: torch.Tensor, dtype: torch.dtype = torch.float32
) -> torch.Tensor:
    """Equivalent to vectorized forward substitution applied to the identity matrix.

    Args:
        tri_tensor: Batched strictly lower-triangular matrix [..., N, N]
            (the `tril(-1)` interaction matrix of a chunk).
        dtype: Computation/output dtype.

    Returns:
        The [..., N, N] inverse of (I - tri_tensor), built row by row.
    """
    tri_tensor = tri_tensor.to(dtype=dtype).clone()
    chunk_size = tri_tensor.shape[-1]

    # Row-wise in-place accumulation: row i folds in the contribution of all
    # previously processed rows. The .clone() calls avoid read/write aliasing
    # on the tensor being updated in place.
    for i in range(1, chunk_size):
        tri_tensor[..., i, :i] = tri_tensor[..., i, :i] + (
            tri_tensor[..., i, :, None].clone() * tri_tensor[..., :, :i].clone()
        ).sum(-2)

    # Add the identity to complete the inverse.
    tri_tensor = tri_tensor + torch.eye(
        chunk_size, dtype=dtype, device=tri_tensor.device
    )
    return tri_tensor.to(dtype=dtype)
1439
+
1440
+
1441
def get_valid_chunk_size(total_l: int, chunk_size: int) -> int:
    """Return the largest divisor of ``total_l`` not exceeding ``chunk_size`` (1 as fallback)."""
    upper = min(chunk_size, total_l)
    divisors = (c for c in range(upper, 0, -1) if total_l % c == 0)
    return next(divisors, 1)
1447
+
1448
+
1449
+ ## RARELY
1450
def split_qkv(
    base_attn: nn.Module, qkv: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Split a fused QKV tensor into separate Q, K, and V tensors.

    Args:
        base_attn: Attention module carrying ``num_q_heads``, ``num_k_heads``,
            ``num_v_heads`` and ``head_dim`` attributes.
        qkv: Fused projection output; last dim is (q + k + v) * head_dim.

    Returns:
        Tuple (q, k, v) split along the last dimension.

    Raises:
        ValueError: If any head count or ``head_dim`` is undefined.
    """
    num_q_heads = getattr(base_attn, "num_q_heads", None)
    num_k_heads = getattr(base_attn, "num_k_heads", None)
    num_v_heads = getattr(base_attn, "num_v_heads", None)
    head_dim = getattr(base_attn, "head_dim", None)

    if num_q_heads is None or num_k_heads is None or num_v_heads is None:
        raise ValueError(
            "Base attention must have num_q_heads, num_k_heads, and num_v_heads defined."
        )
    if head_dim is None:
        # Previously this fell through to an opaque TypeError on the multiply.
        raise ValueError("Base attention must have head_dim defined.")

    q_len = num_q_heads * head_dim
    k_len = num_k_heads * head_dim
    v_len = num_v_heads * head_dim

    q, k, v = torch.split(qkv, [q_len, k_len, v_len], dim=-1)
    return q, k, v
1470
+
1471
+
1472
+ ## OPTIONAL
1473
def match_dim(x: torch.Tensor, dim: int, target_size: int) -> torch.Tensor:
    """Resize x along `dim` to `target_size`: linear interpolation to grow, identity projection (truncation) to shrink."""
    current = x.shape[dim]
    if current == target_size:
        return x
    moved = torch.moveaxis(x, dim, -1)
    kept_shape = moved.shape
    if current < target_size:
        flat = moved.reshape(-1, 1, current)
        grown = F.interpolate(flat, size=target_size, mode="linear", align_corners=False)
        moved = grown.reshape(*kept_shape[:-1], target_size)
    else:
        projector = torch.eye(target_size, current, device=moved.device, dtype=moved.dtype)
        moved = F.linear(moved, projector)  # pylint: disable=not-callable
    return torch.moveaxis(moved, -1, dim)
1489
+
1490
+
1491
def soft_clamp(
    x: torch.Tensor, min_val: float = 1e-6, max_val: float = 1 - 1e-6
) -> torch.Tensor:
    """Smooth, differentiable clamp of x into (min_val, max_val) via tanh."""
    original_dtype = x.dtype
    half_range = (max_val - min_val) / 2
    midpoint = (max_val + min_val) / 2
    squeezed = torch.tanh((x - midpoint) / half_range) * half_range + midpoint
    return squeezed.to(dtype=original_dtype)
1499
+
1500
+
1501
def describe(x: torch.Tensor, name="tensor") -> None:
    """Print shape, min/max/mean/std, dtype and device of a tensor (debug helper)."""
    mn, mx, mean, std = x.min(), x.max(), x.mean(), x.std()
    summary = (
        f"{name} shape: {tuple(x.shape)}, "
        + f"min: {mn:.4g}, max: {mx:.4g}, "
        + f"mean: {mean:.4g}, std: {std:.4g}, "
        + f"dtype: {x.dtype}, device: {x.device}"
    )
    print(summary)
train_tptt.py ADDED
@@ -0,0 +1,142 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # pylint: disable=too-many-arguments, too-many-positional-arguments
2
+
3
+ """
4
+ Author : Fabien FURFARO
5
+ """
6
+
7
+ from typing import Optional, Union
8
+
9
+ from transformers import PreTrainedModel, TrainerCallback
10
+
11
+ from .modeling_tptt import LiZAttention
12
+
13
+
14
class LiZACallback(TrainerCallback):
    """
    TrainerCallback to schedule mag_weight or enable/disable linear attention during training.

    Modes:
        - "constant": always apply final_weight.
        - "gradual": linear interpolation from initial_weight to final_weight.
        - "cyclic": alternate between values in weight_list at each step.
        - "switch": alternately enable/disable linear attention at each step.
    """

    def __init__(
        self,
        model: PreTrainedModel,
        mode: str = "gradual",
        initial_weight: float = 0.0,
        final_weight: float = 0.5,
        transition_step: Union[int, tuple, list] = 100,
        weight_list: Optional[list] = None,
        switch_period: int = 1,  # period for switching
    ):
        self.model = model
        self.mode = mode

        # Ensure initial_weight/final_weight are float scalars, not tuple/list
        if isinstance(initial_weight, (tuple, list)):
            initial_weight = initial_weight[0]
        if isinstance(final_weight, (tuple, list)):
            final_weight = final_weight[0]
        self.initial_weight = float(initial_weight)
        self.final_weight = float(final_weight)

        # Ensure transition_step is an int scalar, not tuple/list
        self.transition_step = ensure_int(transition_step)
        if self.mode == "constant":
            # For constant mode, transition_step is not used
            self.initial_weight = self.final_weight
        # For cyclic mode: ensure all weights are float scalars
        if weight_list is not None:
            self.weight_list = [
                float(w[0]) if isinstance(w, (tuple, list)) else float(w)
                for w in weight_list
            ]
        else:
            self.weight_list = [self.initial_weight, self.final_weight]

        # For switch mode
        self.switch_period = int(switch_period)

    def _set_mag_weight(self, weight: float) -> None:
        """Apply ``weight`` to every LiZAttention module of the model."""
        for _, module in self.model.named_modules():
            if isinstance(module, LiZAttention):
                module.mag_weight = weight

    def on_step_end(self, args, state, control, **kwargs):
        """Update mag_weight / linear-attention toggles after each step."""
        current_step = ensure_int(state.global_step)
        transition_step = self.transition_step  # already an int (see __init__)

        # Select mag_weight or enable/disable linear attention according to mode
        if self.mode == "constant":
            self._set_mag_weight(self.final_weight)

        elif self.mode == "gradual":
            # Guard transition_step <= 0 (previously raised ZeroDivisionError).
            if transition_step > 0 and current_step <= transition_step:
                weight = self.initial_weight + (
                    self.final_weight - self.initial_weight
                ) * (current_step / transition_step)
            else:
                weight = self.final_weight
            self._set_mag_weight(weight)

        elif self.mode == "cyclic":
            idx = current_step % len(self.weight_list)
            self._set_mag_weight(self.weight_list[idx])

        elif self.mode == "switch":
            # Alternately enable/disable linear attention every switch_period steps
            disable = (current_step // self.switch_period) % 2 == 0
            for _, module in self.model.named_modules():
                if isinstance(module, LiZAttention):
                    module.disable_linear_attn = disable

        else:
            raise ValueError(f"Unknown mode: {self.mode}")

    def on_log(self, args, state, control, logs=None, **kwargs):
        """Expose the current mag_weight / disable_linear_attn in the logs."""
        if logs is None:
            return
        # Read the values from the first LiZAttention module found.
        for _, module in self.model.named_modules():
            if isinstance(module, LiZAttention):
                mag_weight = getattr(module, "mag_weight", None)
                disable_linear_attn = getattr(module, "disable_linear_attn", None)
                if mag_weight is not None:
                    logs["mag_weight"] = float(mag_weight)
                if disable_linear_attn is not None:
                    logs["disable_linear_attn"] = bool(disable_linear_attn)
                break
+
120
+
121
def ensure_int(value: Union[int, tuple, list]) -> int:
    """Ensure the value is a plain integer.

    Accepts a scalar, a tuple/list (first element is used), or a 0-d
    tensor/NumPy scalar exposing ``.item()``.

    Returns:
        The value coerced to a built-in ``int`` (previously non-int scalars
        such as floats were returned unchanged).
    """
    if isinstance(value, (tuple, list)):
        value = value[0]
    if hasattr(value, "item"):
        value = value.item()
    return int(value)
128
+
129
+
130
class SaveBestModelCallback(TrainerCallback):
    """TrainerCallback that triggers a checkpoint save whenever eval_loss improves."""

    def __init__(self):
        # Best (lowest) eval_loss seen so far.
        self.best_metric = float("inf")

    def on_evaluate(self, args, state, control, metrics=None, **kwargs):
        """Set control.should_save based on whether eval_loss improved."""
        eval_loss = (metrics or {}).get("eval_loss")
        if eval_loss is None:
            return
        improved = eval_loss < self.best_metric
        if improved:
            self.best_metric = eval_loss
        control.should_save = improved  # save only on improvement