littlebird13 committed (verified)
Commit 976bece · 1 Parent(s): 775fc2e

Upload folder using huggingface_hub

README.md CHANGED
@@ -1,14 +1,51 @@
  ---
- title: Qwen3 ASR
- emoji: 🌖
- colorFrom: yellow
+ title: Qwen3-ASR Demo
+ emoji: 🎙️
+ colorFrom: blue
  colorTo: purple
  sdk: gradio
- sdk_version: 6.4.0
- python_version: '3.12'
+ sdk_version: 5.33.0
  app_file: app.py
  pinned: false
  license: apache-2.0
  ---

- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+ # Qwen3-ASR Demo
+
+ This Space demonstrates **Qwen3-ASR-1.7B**, a state-of-the-art automatic speech recognition model from the Qwen team, powered by **vLLM** for high-speed inference.
+
+ ## Features
+
+ - **30+ Language Support**: Chinese, Cantonese, English, Japanese, Korean, Arabic, German, French, Spanish, Portuguese, and many more
+ - **Word/Character-level Timestamps**: Accurate timestamp alignment for each word (English) or character (Chinese)
+ - **Interactive Visualization**: Click on each word/character to hear the corresponding audio segment
+ - **vLLM Backend**: Fast inference speed for real-time transcription
+
+ ## How to Use
+
+ 1. Upload an audio file or record using your microphone
+ 2. Select a language or leave "Auto" for automatic detection
+ 3. Enable "Timestamps" for visualization (recommended)
+ 4. Click "Transcribe" and see the results
+
+ ## Models Used
+
+ - **ASR Model**: [Qwen/Qwen3-ASR-1.7B](https://huggingface.co/Qwen/Qwen3-ASR-1.7B)
+ - **Forced Aligner**: [Qwen/Qwen3-ForcedAligner-0.6B](https://huggingface.co/Qwen/Qwen3-ForcedAligner-0.6B)
+
+ ## Setup (For Space Owners)
+
+ This Space requires access to private models. You need to set up the `HF_TOKEN` secret:
+
+ 1. Go to your Space Settings
+ 2. Navigate to "Repository secrets"
+ 3. Add a new secret named `HF_TOKEN` with your Hugging Face access token as the value
+
+ ## Links
+
+ - [GitHub Repository](https://github.com/Qwen/Qwen3-ASR)
+ - [Model Card](https://huggingface.co/Qwen/Qwen3-ASR-1.7B)
+
+ ## License
+
+ Apache 2.0
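For Space owners, the `HF_TOKEN` secret is exposed to the running Space as an environment variable. A minimal sketch of how `app.py` could pick it up when fetching the gated checkpoints (this repo's actual token handling is not shown in this diff, so the snippet is illustrative only):

```python
import os

from huggingface_hub import snapshot_download

# Repository secrets are injected into the Space as environment variables.
token = os.environ.get("HF_TOKEN")
if token is None:
    raise RuntimeError("HF_TOKEN is not set; add it under Settings -> Repository secrets.")

# Download the private/gated checkpoints with the token.
asr_path = snapshot_download("Qwen/Qwen3-ASR-1.7B", token=token)
aligner_path = snapshot_download("Qwen/Qwen3-ForcedAligner-0.6B", token=token)
print(asr_path, aligner_path)
```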
app.py CHANGED
@@ -371,4 +371,4 @@ This demo showcases the 1.7B model which provides excellent multilingual recogni
  
  
  if __name__ == "__main__":
-     demo.launch()
+     demo.queue(default_concurrency_limit=4).launch()
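The only change to `app.py` is swapping `demo.launch()` for `demo.queue(default_concurrency_limit=4).launch()`, which caps each event handler at four concurrent executions and queues the rest. A standalone sketch of the same pattern (the toy handler below is hypothetical; only the `queue(...).launch()` call mirrors the diff):

```python
import gradio as gr


def echo(text: str) -> str:
    # Stand-in for the real transcription callback.
    return text


with gr.Blocks() as demo:
    inp = gr.Textbox(label="Input")
    out = gr.Textbox(label="Output")
    inp.submit(echo, inputs=inp, outputs=out)

if __name__ == "__main__":
    # At most 4 concurrent runs per event; additional requests wait in the queue.
    demo.queue(default_concurrency_limit=4).launch()
```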
qwen_asr/__init__.py ADDED
@@ -0,0 +1,25 @@
+ # coding=utf-8
+ # Copyright 2026 The Alibaba Qwen team.
+ # SPDX-License-Identifier: Apache-2.0
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ """
+ qwen_asr: Qwen3-ASR package.
+ """
+
+ from .inference.qwen3_asr import Qwen3ASRModel
+ from .inference.qwen3_forced_aligner import Qwen3ForcedAligner
+
+ from .inference.utils import parse_asr_output
+
+ __all__ = ["Qwen3ASRModel", "Qwen3ForcedAligner", "parse_asr_output"]
qwen_asr/__main__.py ADDED
@@ -0,0 +1,25 @@
+ # coding=utf-8
+ # Copyright 2026 The Alibaba Qwen team.
+ # SPDX-License-Identifier: Apache-2.0
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ def main():
+     print(
+         "qwen_asr package.\n"
+         "Use CLI entrypoints:\n"
+         " - qwen-asr-demo\n"
+         " - qwen-asr-serve\n"
+     )
+
+ if __name__ == "__main__":
+     main()
qwen_asr/cli/demo.py ADDED
@@ -0,0 +1,523 @@
1
+ # coding=utf-8
2
+ # Copyright 2026 The Alibaba Qwen team.
3
+ # SPDX-License-Identifier: Apache-2.0
4
+ """
5
+ A Gradio demo for Qwen3 ASR models.
6
+ """
7
+
8
+ import argparse
9
+ import base64
10
+ import io
11
+ import json
12
+ import os
13
+ from typing import Any, Dict, List, Optional, Tuple, Union
14
+
15
+ import gradio as gr
16
+ import numpy as np
17
+ import torch
18
+ from qwen_asr import Qwen3ASRModel
19
+ from scipy.io.wavfile import write as wav_write
20
+
21
+
22
+ def _title_case_display(s: str) -> str:
23
+ s = (s or "").strip()
24
+ s = s.replace("_", " ")
25
+ return " ".join([w[:1].upper() + w[1:] if w else "" for w in s.split()])
26
+
27
+
28
+ def _build_choices_and_map(items: Optional[List[str]]) -> Tuple[List[str], Dict[str, str]]:
29
+ if not items:
30
+ return [], {}
31
+ display = [_title_case_display(x) for x in items]
32
+ mapping = {d: r for d, r in zip(display, items)}
33
+ return display, mapping
34
+
35
+
36
+ def _dtype_from_str(s: str) -> torch.dtype:
37
+ s = (s or "").strip().lower()
38
+ if s in ("bf16", "bfloat16"):
39
+ return torch.bfloat16
40
+ if s in ("fp16", "float16", "half"):
41
+ return torch.float16
42
+ if s in ("fp32", "float32"):
43
+ return torch.float32
44
+ raise ValueError(f"Unsupported torch dtype: {s}. Use bfloat16/float16/float32.")
45
+
46
+
47
+ def _normalize_audio(wav, eps=1e-12, clip=True):
48
+ x = np.asarray(wav)
49
+
50
+ if np.issubdtype(x.dtype, np.integer):
51
+ info = np.iinfo(x.dtype)
52
+ if info.min < 0:
53
+ y = x.astype(np.float32) / max(abs(info.min), info.max)
54
+ else:
55
+ mid = (info.max + 1) / 2.0
56
+ y = (x.astype(np.float32) - mid) / mid
57
+ elif np.issubdtype(x.dtype, np.floating):
58
+ y = x.astype(np.float32)
59
+ m = np.max(np.abs(y)) if y.size else 0.0
60
+ if m > 1.0 + 1e-6:
61
+ y = y / (m + eps)
62
+ else:
63
+ raise TypeError(f"Unsupported dtype: {x.dtype}")
64
+
65
+ if clip:
66
+ y = np.clip(y, -1.0, 1.0)
67
+
68
+ if y.ndim > 1:
69
+ y = np.mean(y, axis=-1).astype(np.float32)
70
+
71
+ return y
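A quick self-check of `_normalize_audio` on the int16 stereo payloads Gradio's microphone component typically produces (importing the helper pulls in the module's gradio/torch dependencies, so treat this as a sketch):

```python
import numpy as np

from qwen_asr.cli.demo import _normalize_audio

# Typical microphone payload: int16 PCM, two channels.
rng = np.random.default_rng(0)
pcm = rng.integers(-32768, 32767, size=(16000, 2)).astype(np.int16)

wav = _normalize_audio(pcm)

# The helper returns mono float32 clipped to [-1, 1].
assert wav.dtype == np.float32
assert wav.ndim == 1
assert np.abs(wav).max() <= 1.0
```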
72
+
73
+
74
+ def _audio_to_tuple(audio: Any) -> Optional[Tuple[np.ndarray, int]]:
75
+ """
76
+ Accept gradio audio:
77
+ - {"sampling_rate": int, "data": np.ndarray}
78
+ - (sr, np.ndarray) [some gradio versions]
79
+ Return: (wav_float32_mono, sr)
80
+ """
81
+ if audio is None:
82
+ return None
83
+
84
+ if isinstance(audio, dict) and "sampling_rate" in audio and "data" in audio:
85
+ sr = int(audio["sampling_rate"])
86
+ wav = _normalize_audio(audio["data"])
87
+ return wav, sr
88
+
89
+ if isinstance(audio, tuple) and len(audio) == 2:
90
+ a0, a1 = audio
91
+ if isinstance(a0, int):
92
+ sr = int(a0)
93
+ wav = _normalize_audio(a1)
94
+ return wav, sr
95
+ if isinstance(a1, int):
96
+ wav = _normalize_audio(a0)
97
+ sr = int(a1)
98
+ return wav, sr
99
+
100
+ return None
101
+
102
+
103
+ def _parse_audio_any(audio: Any) -> Union[str, Tuple[np.ndarray, int]]:
104
+ if audio is None:
105
+ raise ValueError("Audio is required.")
106
+ at = _audio_to_tuple(audio)
107
+ if at is not None:
108
+ return at
109
+ raise ValueError("Unsupported audio input format.")
110
+
111
+
112
+ def build_parser() -> argparse.ArgumentParser:
113
+ parser = argparse.ArgumentParser(
114
+ prog="qwen-asr-demo",
115
+ description=(
116
+ "Launch a Gradio demo for Qwen3 ASR models (Transformers / vLLM).\n\n"
117
+ "Examples:\n"
118
+ " qwen-asr-demo --asr-checkpoint Qwen/Qwen3-ASR-1.7B\n"
119
+ " qwen-asr-demo --asr-checkpoint Qwen/Qwen3-ASR-1.7B --aligner-checkpoint Qwen/Qwen3-ForcedAligner-0.6B\n"
120
+ " qwen-asr-demo --backend vllm --cuda-visible-devices 0\n"
121
+ " qwen-asr-demo --backend transformers --backend-kwargs '{\"device_map\":\"cuda:0\",\"dtype\":\"bfloat16\",\"attn_implementation\":\"flash_attention_2\"}'\n"
122
+ " qwen-asr-demo --backend vllm --backend-kwargs '{\"gpu_memory_utilization\":0.85}'\n"
123
+ ),
124
+ formatter_class=argparse.RawTextHelpFormatter,
125
+ add_help=True,
126
+ )
127
+
128
+ parser.add_argument("--asr-checkpoint", required=True, help="Qwen3-ASR model checkpoint path or HF repo id.")
129
+ parser.add_argument(
130
+ "--aligner-checkpoint",
131
+ default=None,
132
+ help="Qwen3-ForcedAligner checkpoint path or HF repo id (optional; enables timestamps when provided).",
133
+ )
134
+
135
+ parser.add_argument(
136
+ "--backend",
137
+ default="transformers",
138
+ choices=["transformers", "vllm"],
139
+ help="Backend for ASR model loading (default: transformers).",
140
+ )
141
+
142
+ parser.add_argument(
143
+ "--cuda-visible-devices",
144
+ default="0",
145
+ help=(
146
+ "Set CUDA_VISIBLE_DEVICES for the demo process (default: 0). "
147
+ "Use e.g. '0' or '1'"
148
+ ),
149
+ )
150
+
151
+ parser.add_argument(
152
+ "--backend-kwargs",
153
+ default=None,
154
+ help=(
155
+ "JSON dict for backend-specific kwargs excluding checkpoints.\n"
156
+ "Examples:\n"
157
+ " transformers: '{\"device_map\":\"cuda:0\",\"dtype\":\"bfloat16\",\"attn_implementation\":\"flash_attention_2\",\"max_inference_batch_size\":32}'\n"
158
+ " vllm : '{\"gpu_memory_utilization\":0.8,\"max_inference_batch_size\":32}'\n"
159
+ ),
160
+ )
161
+ parser.add_argument(
162
+ "--aligner-kwargs",
163
+ default=None,
164
+ help=(
165
+ "JSON dict for forced aligner kwargs (only used when --aligner-checkpoint is set).\n"
166
+ "Example: '{\"dtype\":\"bfloat16\",\"device_map\":\"cuda:0\"}'\n"
167
+ ),
168
+ )
169
+
170
+ # Gradio server args
171
+ parser.add_argument("--ip", default="0.0.0.0", help="Server bind IP for Gradio (default: 0.0.0.0).")
172
+ parser.add_argument("--port", type=int, default=8000, help="Server port for Gradio (default: 8000).")
173
+ parser.add_argument(
174
+ "--share/--no-share",
175
+ dest="share",
176
+ default=False,
177
+ action=argparse.BooleanOptionalAction,
178
+ help="Whether to create a public Gradio link (default: disabled).",
179
+ )
180
+ parser.add_argument("--concurrency", type=int, default=16, help="Gradio queue concurrency (default: 16).")
181
+
182
+ # HTTPS args
183
+ parser.add_argument("--ssl-certfile", default=None, help="Path to SSL certificate file for HTTPS (optional).")
184
+ parser.add_argument("--ssl-keyfile", default=None, help="Path to SSL key file for HTTPS (optional).")
185
+ parser.add_argument(
186
+ "--ssl-verify/--no-ssl-verify",
187
+ dest="ssl_verify",
188
+ default=True,
189
+ action=argparse.BooleanOptionalAction,
190
+ help="Whether to verify SSL certificate (default: enabled).",
191
+ )
192
+
193
+ return parser
194
+
195
+
196
+ def _parse_json_dict(s: Optional[str], *, name: str) -> Dict[str, Any]:
197
+ if s is None or not str(s).strip():
198
+ return {}
199
+ try:
200
+ obj = json.loads(s)
201
+ except Exception as e:
202
+ raise ValueError(f"Invalid JSON for {name}: {e}")
203
+ if not isinstance(obj, dict):
204
+ raise ValueError(f"{name} must be a JSON object (dict).")
205
+ return obj
206
+
207
+
208
+ def _apply_cuda_visible_devices(cuda_visible_devices: str) -> None:
209
+ v = (cuda_visible_devices or "").strip()
210
+ if not v:
211
+ return
212
+ os.environ["CUDA_VISIBLE_DEVICES"] = v
213
+
214
+
215
+ def _default_backend_kwargs(backend: str) -> Dict[str, Any]:
216
+ if backend == "transformers":
217
+ return dict(
218
+ dtype=torch.bfloat16,
219
+ device_map="cuda:0",
220
+ attn_implementation="flash_attention_2",
221
+ max_inference_batch_size=32,
222
+ )
223
+ else:
224
+ return dict(
225
+ gpu_memory_utilization=0.8,
226
+ max_inference_batch_size=32,
227
+ )
228
+
229
+
230
+ def _default_aligner_kwargs() -> Dict[str, Any]:
231
+ return dict(
232
+ dtype=torch.bfloat16,
233
+ device_map="cuda:0",
234
+ )
235
+
236
+
237
+ def _merge_dicts(base: Dict[str, Any], override: Dict[str, Any]) -> Dict[str, Any]:
238
+ out = dict(base)
239
+ out.update(override)
240
+ return out
241
+
242
+
243
+ def _coerce_special_types(d: Dict[str, Any]) -> Dict[str, Any]:
244
+ out: Dict[str, Any] = {}
245
+ for k, v in d.items():
246
+ if k == "dtype" and isinstance(v, str):
247
+ out[k] = _dtype_from_str(v)
248
+ else:
249
+ out[k] = v
250
+ return out
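Taken together, `_parse_json_dict`, `_merge_dicts`, and `_coerce_special_types` implement the `--backend-kwargs` pipeline: defaults, then user overrides, then string-to-`torch.dtype` coercion. A sketch with illustrative values:

```python
import torch

from qwen_asr.cli.demo import (
    _coerce_special_types,
    _default_backend_kwargs,
    _merge_dicts,
    _parse_json_dict,
)

# What a user might pass on the command line for the vLLM backend.
raw = '{"gpu_memory_utilization": 0.85, "dtype": "bfloat16"}'

user = _parse_json_dict(raw, name="--backend-kwargs")
kwargs = _coerce_special_types(_merge_dicts(_default_backend_kwargs("vllm"), user))

assert kwargs["gpu_memory_utilization"] == 0.85  # user value overrides the 0.8 default
assert kwargs["dtype"] is torch.bfloat16         # "bfloat16" coerced to a torch dtype
assert kwargs["max_inference_batch_size"] == 32  # untouched default
```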
251
+
252
+
253
+ def _make_timestamp_html(audio_upload: Any, timestamps: Any) -> str:
254
+ """
255
+ Build HTML with per-token audio slices, using base64 data URLs (no filesystem caching).
256
+ Expect timestamps as list[dict] with keys: text, start_time, end_time (in seconds).
257
+ """
258
+ at = _audio_to_tuple(audio_upload)
259
+ if at is None:
260
+ raise ValueError("Audio input is required for visualization.")
261
+ audio, sr = at
262
+
263
+ if not timestamps:
264
+ return "<div style='color:#666'>No timestamps to visualize.</div>"
265
+ if not isinstance(timestamps, list):
266
+ raise ValueError("Timestamps must be a list (JSON array).")
267
+
268
+ html_content = """
269
+ <style>
270
+ .word-alignment-container { display: flex; flex-wrap: wrap; gap: 10px; }
271
+ .word-box {
272
+ border: 1px solid #ddd; border-radius: 8px; padding: 10px;
273
+ background-color: #f9f9f9; box-shadow: 0 2px 4px rgba(0,0,0,0.06);
274
+ text-align: center;
275
+ }
276
+ .word-text { font-size: 18px; font-weight: 700; margin-bottom: 5px; }
277
+ .word-time { font-size: 12px; color: #666; margin-bottom: 8px; }
278
+ .word-audio audio { width: 140px; height: 30px; }
279
+ details { border: 1px solid #ddd; border-radius: 6px; padding: 10px; background-color: #f7f7f7; }
280
+ summary { font-weight: 700; cursor: pointer; }
281
+ </style>
282
+ """
283
+
284
+ html_content += """
285
+ <details open>
286
+ <summary>Timestamps Visualization (时间戳可视化结果)</summary>
287
+ <div class="word-alignment-container" style="margin-top: 14px;">
288
+ """
289
+
290
+ for item in timestamps:
291
+ if not isinstance(item, dict):
292
+ continue
293
+ word = str(item.get("text", "") or "")
294
+ start = item.get("start_time", None)
295
+ end = item.get("end_time", None)
296
+ if start is None or end is None:
297
+ continue
298
+
299
+ start = float(start)
300
+ end = float(end)
301
+ if end <= start:
302
+ continue
303
+
304
+ start_sample = max(0, int(start * sr))
305
+ end_sample = min(len(audio), int(end * sr))
306
+ if end_sample <= start_sample:
307
+ continue
308
+
309
+ seg = audio[start_sample:end_sample]
310
+ seg_i16 = (np.clip(seg, -1.0, 1.0) * 32767.0).astype(np.int16)
311
+
312
+ mem = io.BytesIO()
313
+ wav_write(mem, sr, seg_i16)
314
+ mem.seek(0)
315
+ b64 = base64.b64encode(mem.read()).decode("utf-8")
316
+ audio_src = f"data:audio/wav;base64,{b64}"
317
+
318
+ html_content += f"""
319
+ <div class="word-box">
320
+ <div class="word-text">{word}</div>
321
+ <div class="word-time">{start} - {end} s</div>
322
+ <div class="word-audio">
323
+ <audio controls preload="none" src="{audio_src}"></audio>
324
+ </div>
325
+ </div>
326
+ """
327
+
328
+ html_content += "</div></details>"
329
+ return html_content
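The visualizer takes the raw Gradio audio plus the list-of-dicts payload that `run()` later builds from the aligner output. A sketch with a synthetic tone and hand-written timings in seconds:

```python
import numpy as np

from qwen_asr.cli.demo import _make_timestamp_html

sr = 16000
t = np.arange(sr) / sr
tone = (0.1 * np.sin(2 * np.pi * 440.0 * t)).astype(np.float32)

# Gradio-style (sample_rate, waveform) tuple and illustrative word timings.
audio = (sr, tone)
timestamps = [
    {"text": "hello", "start_time": 0.00, "end_time": 0.45},
    {"text": "world", "start_time": 0.50, "end_time": 0.95},
]

html = _make_timestamp_html(audio, timestamps)
assert "data:audio/wav;base64," in html  # each word gets an inline playable slice
```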
330
+
331
+
332
+ def build_demo(
333
+ asr: Qwen3ASRModel,
334
+ asr_ckpt: str,
335
+ backend: str,
336
+ aligner_ckpt: Optional[str] = None,
337
+ ) -> gr.Blocks:
338
+ supported_langs_raw = asr.get_supported_languages()
339
+ lang_choices_disp, lang_map = _build_choices_and_map([x for x in supported_langs_raw])
340
+ lang_choices = ["Auto"] + lang_choices_disp
341
+
342
+ has_aligner = bool(aligner_ckpt)
343
+
344
+ theme = gr.themes.Soft(
345
+ font=[gr.themes.GoogleFont("Source Sans Pro"), "Arial", "sans-serif"],
346
+ )
347
+ css = ".gradio-container {max-width: none !important;}"
348
+
349
+ with gr.Blocks(theme=theme, css=css) as demo:
350
+ gr.Markdown(
351
+ f"""
352
+ # Qwen3 ASR Demo
353
+ **Backend:** `{backend}`
354
+ **ASR Checkpoint:** `{asr_ckpt}`
355
+ **Forced Aligner:** `{aligner_ckpt if aligner_ckpt else "(none)"}`
356
+ """
357
+ )
358
+
359
+ with gr.Row():
360
+ with gr.Column(scale=2):
361
+ audio_in = gr.Audio(label="Audio Input (上传音频)", type="numpy")
362
+ lang_in = gr.Dropdown(
363
+ label="Language (语种)",
364
+ choices=lang_choices,
365
+ value="Auto",
366
+ interactive=True,
367
+ )
368
+ if has_aligner:
369
+ ts_in = gr.Checkbox(
370
+ label="Return Timestamps (是否返回时间戳)",
371
+ value=True,
372
+ )
373
+ else:
374
+ ts_in = gr.State(False)
375
+
376
+ btn = gr.Button("Transcribe (识别)", variant="primary")
377
+
378
+ with gr.Column(scale=2):
379
+ out_lang = gr.Textbox(label="Detected Language", lines=1)
380
+ out_text = gr.Textbox(label="Result Text", lines=12)
381
+
382
+ if has_aligner:
383
+ with gr.Column(scale=3):
384
+ out_ts = gr.JSON(label="Timestamps(时间戳结果)")
385
+ viz_btn = gr.Button("Visualize Timestamps (可视化时间戳)", variant="secondary")
386
+ else:
387
+ with gr.Column(scale=3):
388
+ out_ts = gr.State(None)
389
+ viz_btn = gr.State(None)
390
+
391
+ # Put the visualization panel below the three columns
392
+ if has_aligner:
393
+ with gr.Row():
394
+ out_ts_html = gr.HTML(label="Timestamps Visualization (时间戳可视化结果)")
395
+ else:
396
+ out_ts_html = gr.State("")
397
+
398
+ def run(audio_upload: Any, lang_disp: str, return_ts: bool):
399
+ audio_obj = _parse_audio_any(audio_upload)
400
+
401
+ language = None
402
+ if lang_disp and lang_disp != "Auto":
403
+ language = lang_map.get(lang_disp, lang_disp)
404
+
405
+ return_ts = bool(return_ts) and has_aligner
406
+
407
+ results = asr.transcribe(
408
+ audio=audio_obj,
409
+ language=language,
410
+ return_time_stamps=return_ts,
411
+ )
412
+ if not isinstance(results, list) or len(results) != 1:
413
+ raise RuntimeError(
414
+ f"Unexpected result size: {type(results)} "
415
+ f"len={len(results) if isinstance(results, list) else 'N/A'}"
416
+ )
417
+
418
+ r = results[0]
419
+
420
+ if has_aligner:
421
+ ts_payload = None
422
+ if return_ts:
423
+ ts_payload = [
424
+ dict(
425
+ text=getattr(t, "text", None),
426
+ start_time=getattr(t, "start_time", None),
427
+ end_time=getattr(t, "end_time", None),
428
+ )
429
+ for t in (getattr(r, "time_stamps", None) or [])
430
+ ]
431
+ return (
432
+ getattr(r, "language", "") or "",
433
+ getattr(r, "text", "") or "",
434
+ gr.update(value=ts_payload) if return_ts else gr.update(value=None),
435
+ gr.update(value=""), # clear html on each transcribe
436
+ )
437
+ else:
438
+ return (
439
+ getattr(r, "language", "") or "",
440
+ getattr(r, "text", "") or "",
441
+ )
442
+
443
+ def visualize(audio_upload: Any, timestamps_json: Any):
444
+ return _make_timestamp_html(audio_upload, timestamps_json)
445
+
446
+ if has_aligner:
447
+ btn.click(
448
+ run,
449
+ inputs=[audio_in, lang_in, ts_in],
450
+ outputs=[out_lang, out_text, out_ts, out_ts_html],
451
+ )
452
+ viz_btn.click(
453
+ visualize,
454
+ inputs=[audio_in, out_ts],
455
+ outputs=[out_ts_html],
456
+ )
457
+ else:
458
+ btn.click(
459
+ run,
460
+ inputs=[audio_in, lang_in, ts_in],
461
+ outputs=[out_lang, out_text],
462
+ )
463
+
464
+ return demo
465
+
466
+
467
+ def main(argv=None) -> int:
468
+ parser = build_parser()
469
+ args = parser.parse_args(argv)
470
+
471
+ _apply_cuda_visible_devices(args.cuda_visible_devices)
472
+
473
+ backend = args.backend
474
+ asr_ckpt = args.asr_checkpoint
475
+ aligner_ckpt = args.aligner_checkpoint
476
+
477
+ user_backend_kwargs = _parse_json_dict(args.backend_kwargs, name="--backend-kwargs")
478
+ user_aligner_kwargs = _parse_json_dict(args.aligner_kwargs, name="--aligner-kwargs")
479
+
480
+ backend_kwargs = _merge_dicts(_default_backend_kwargs(backend), user_backend_kwargs)
481
+ backend_kwargs = _coerce_special_types(backend_kwargs)
482
+
483
+ forced_aligner = None
484
+ forced_aligner_kwargs = None
485
+ if aligner_ckpt:
486
+ forced_aligner = aligner_ckpt
487
+ aligner_kwargs = _merge_dicts(_default_aligner_kwargs(), user_aligner_kwargs)
488
+ forced_aligner_kwargs = _coerce_special_types(aligner_kwargs)
489
+
490
+ if backend == "transformers":
491
+ asr = Qwen3ASRModel.from_pretrained(
492
+ asr_ckpt,
493
+ forced_aligner=forced_aligner,
494
+ forced_aligner_kwargs=forced_aligner_kwargs,
495
+ **backend_kwargs,
496
+ )
497
+ else:
498
+ asr = Qwen3ASRModel.LLM(
499
+ model=asr_ckpt,
500
+ forced_aligner=forced_aligner,
501
+ forced_aligner_kwargs=forced_aligner_kwargs,
502
+ **backend_kwargs,
503
+ )
504
+
505
+ demo = build_demo(asr, asr_ckpt, backend, aligner_ckpt=aligner_ckpt)
506
+
507
+ launch_kwargs: Dict[str, Any] = dict(
508
+ server_name=args.ip,
509
+ server_port=args.port,
510
+ share=args.share,
511
+ ssl_verify=True if args.ssl_verify else False,
512
+ )
513
+ if args.ssl_certfile is not None:
514
+ launch_kwargs["ssl_certfile"] = args.ssl_certfile
515
+ if args.ssl_keyfile is not None:
516
+ launch_kwargs["ssl_keyfile"] = args.ssl_keyfile
517
+
518
+ demo.queue(default_concurrency_limit=int(args.concurrency)).launch(**launch_kwargs)
519
+ return 0
520
+
521
+
522
+ if __name__ == "__main__":
523
+ raise SystemExit(main())
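`main()` accepts an explicit argv list, so the demo can also be started without the `qwen-asr-demo` console script; a minimal sketch:

```python
from qwen_asr.cli.demo import main

# Equivalent to:
#   qwen-asr-demo --asr-checkpoint Qwen/Qwen3-ASR-1.7B \
#                 --aligner-checkpoint Qwen/Qwen3-ForcedAligner-0.6B --port 7860
raise SystemExit(main([
    "--asr-checkpoint", "Qwen/Qwen3-ASR-1.7B",
    "--aligner-checkpoint", "Qwen/Qwen3-ForcedAligner-0.6B",
    "--port", "7860",
]))
```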
qwen_asr/cli/serve.py ADDED
@@ -0,0 +1,46 @@
+ # coding=utf-8
+ # Copyright 2026 The Alibaba Qwen team.
+ # SPDX-License-Identifier: Apache-2.0
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ import sys
+
+ from qwen_asr.core.transformers_backend import (
+     Qwen3ASRConfig,
+     Qwen3ASRForConditionalGeneration,
+     Qwen3ASRProcessor,
+ )
+ from transformers import AutoConfig, AutoModel, AutoProcessor
+
+ AutoConfig.register("qwen3_asr", Qwen3ASRConfig)
+ AutoModel.register(Qwen3ASRConfig, Qwen3ASRForConditionalGeneration)
+ AutoProcessor.register(Qwen3ASRConfig, Qwen3ASRProcessor)
+
+ try:
+     from qwen_asr.core.vllm_backend import Qwen3ASRForConditionalGeneration
+     from vllm import ModelRegistry
+     ModelRegistry.register_model("Qwen3ASRForConditionalGeneration", Qwen3ASRForConditionalGeneration)
+ except Exception as e:
+     raise ImportError(
+         "vLLM is not available. To use qwen-asr-serve, install it with: pip install qwen-asr[vllm]"
+     ) from e
+
+ from vllm.entrypoints.cli.main import main as vllm_main
+
+ def main():
+     sys.argv.insert(1, "serve")
+     vllm_main()
+
+
+ if __name__ == "__main__":
+     main()
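`qwen-asr-serve` simply pre-registers the Qwen3-ASR architecture and then forwards everything after the program name to `vllm serve`. A sketch of launching it from Python (flags beyond the model name are plain vLLM `serve` options):

```python
import subprocess

# Expands to `vllm serve Qwen/Qwen3-ASR-1.7B --port 8001` with the
# Qwen3ASRForConditionalGeneration architecture already in vLLM's ModelRegistry.
subprocess.run(
    ["qwen-asr-serve", "Qwen/Qwen3-ASR-1.7B", "--port", "8001"],
    check=True,
)
```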
qwen_asr/core/transformers_backend/__init__.py ADDED
@@ -0,0 +1,18 @@
+ # coding=utf-8
+ # Copyright 2026 The Alibaba Qwen team.
+ # SPDX-License-Identifier: Apache-2.0
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ from .configuration_qwen3_asr import Qwen3ASRConfig
+ from .modeling_qwen3_asr import Qwen3ASRForConditionalGeneration
+ from .processing_qwen3_asr import Qwen3ASRProcessor
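These are the same classes that `cli/serve.py` registers with the Auto classes; they can also be used directly with the standard `from_pretrained` API. A sketch, assuming the checkpoint ships a matching `config.json`:

```python
from qwen_asr.core.transformers_backend import (
    Qwen3ASRForConditionalGeneration,
    Qwen3ASRProcessor,
)

# Load processor and model from the concrete classes, bypassing AutoModel.
processor = Qwen3ASRProcessor.from_pretrained("Qwen/Qwen3-ASR-1.7B")
model = Qwen3ASRForConditionalGeneration.from_pretrained("Qwen/Qwen3-ASR-1.7B")
print(model.config.model_type)  # "qwen3_asr"
```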
qwen_asr/core/transformers_backend/configuration_qwen3_asr.py ADDED
@@ -0,0 +1,425 @@
1
+ # coding=utf-8
2
+ # Copyright 2026 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ from transformers.configuration_utils import PretrainedConfig
16
+ from transformers.utils import logging
17
+
18
+
19
+ logger = logging.get_logger(__name__)
20
+
21
+
22
+ class Qwen3ASRAudioEncoderConfig(PretrainedConfig):
23
+ r"""
24
+ This is the configuration class to store the configuration of a [`Qwen3ASRAudioEncoder`]. It is used to instantiate a
25
+ Qwen3-ASR audio encoder according to the specified arguments, defining the model architecture. Instantiating a
26
+ configuration with the defaults will yield a similar configuration to that of the audio encoder of the Qwen2-Audio
27
+ architecture.
28
+
29
+ e.g. [Qwen/Qwen3-ASR-1.7B](https://huggingface.co/Qwen/Qwen3-ASR-1.7B)
30
+
31
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
32
+ documentation from [`PretrainedConfig`] for more information.
33
+
34
+ Args:
35
+ num_mel_bins (`int`, *optional*, defaults to 128):
36
+ Number of mel features used per input features. Should correspond to the value used in the
37
+ `Qwen3ASRProcessor` class.
38
+ encoder_layers (`int`, *optional*, defaults to 32):
39
+ Number of encoder layers.
40
+ encoder_attention_heads (`int`, *optional*, defaults to 20):
41
+ Number of attention heads for each attention layer in the Transformer encoder.
42
+ encoder_ffn_dim (`int`, *optional*, defaults to 5120):
43
+ Dimensionality of the "intermediate" (often named feed-forward) layer in encoder.
44
+ d_model (`int`, *optional*, defaults to 1280):
45
+ Dimensionality of the layers.
46
+ dropout (`float`, *optional*, defaults to 0.0):
47
+ The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
48
+ attention_dropout (`float`, *optional*, defaults to 0.0):
49
+ The dropout ratio for the attention probabilities.
50
+ activation_function (`str`, *optional*, defaults to `"gelu"`):
51
+ The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
52
+ `"relu"`, `"silu"` and `"gelu_new"` are supported.
53
+ activation_dropout (`float`, *optional*, defaults to 0.0):
54
+ The dropout ratio for activations inside the fully connected layer.
55
+ scale_embedding (`bool`, *optional*, defaults to `False`):
56
+ Scale embeddings by dividing by sqrt(d_model).
57
+ initializer_range (`float`, *optional*, defaults to 0.02):
58
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
59
+ max_source_positions (`int`, *optional*, defaults to 1500):
60
+ The maximum sequence length of log-mel filter-bank features that this model might ever be used with.
61
+ n_window (`int`, *optional*, defaults to 100):
62
+ The chunk for conv and flash attn in AudioEncoder.
63
+ output_dim (`int`, *optional*, defaults to 3584):
64
+ The output dimension of AudioEncoder.
65
+
66
+ Example:
67
+
68
+ ```python
69
+ >>> from transformers import Qwen3ASRAudioEncoderConfig, Qwen3ASRAudioEncoder
70
+
71
+ >>> # Initializing a Qwen3ASRAudioEncoderConfig
72
+ >>> configuration = Qwen3ASRAudioEncoderConfig()
73
+
74
+ >>> # Initializing a Qwen3ASRAudioEncoder (with random weights)
75
+ >>> model = Qwen3ASRAudioEncoder(configuration)
76
+
77
+ >>> # Accessing the model configuration
78
+ >>> configuration = model.config
79
+ ```"""
80
+
81
+ model_type = "qwen3_asr_audio_encoder"
82
+
83
+ def __init__(
84
+ self,
85
+ num_mel_bins=128,
86
+ encoder_layers=32,
87
+ encoder_attention_heads=20,
88
+ encoder_ffn_dim=5120,
89
+ d_model=1280,
90
+ dropout=0,
91
+ attention_dropout=0,
92
+ activation_function="gelu",
93
+ activation_dropout=0,
94
+ scale_embedding=False,
95
+ initializer_range=0.02,
96
+ max_source_positions=1500,
97
+ n_window=100,
98
+ output_dim=3584,
99
+ n_window_infer=400,
100
+ conv_chunksize=500,
101
+ downsample_hidden_size=480,
102
+ **kwargs,
103
+ ):
104
+ super().__init__(**kwargs)
105
+
106
+ self.num_mel_bins = num_mel_bins
107
+ self.d_model = d_model
108
+ self.encoder_layers = encoder_layers
109
+ self.encoder_attention_heads = encoder_attention_heads
110
+ self.encoder_ffn_dim = encoder_ffn_dim
111
+ self.dropout = dropout
112
+ self.attention_dropout = attention_dropout
113
+ self.activation_function = activation_function
114
+ self.activation_dropout = activation_dropout
115
+ self.num_hidden_layers = encoder_layers
116
+ self.initializer_range = initializer_range
117
+ self.scale_embedding = scale_embedding # scale factor will be sqrt(d_model) if True
118
+ self.max_source_positions = max_source_positions
119
+ self.n_window = n_window
120
+ self.output_dim = output_dim
121
+ self.n_window_infer = n_window_infer
122
+ self.conv_chunksize = conv_chunksize
123
+ self.downsample_hidden_size = downsample_hidden_size
124
+
125
+
126
+ class Qwen3ASRTextConfig(PretrainedConfig):
127
+ r"""
128
+ This is the configuration class to store the configuration of a [`Qwen3ASRTextModel`]. It is used to instantiate a
129
+ Qwen3-ASR model according to the specified arguments, defining the model architecture. Instantiating a configuration
130
+ with the defaults will yield a similar configuration to that of
131
+ Qwen3-ASR-1.7B [Qwen/Qwen3-ASR-1.7B](https://huggingface.co/Qwen/Qwen3-ASR-1.7B)
132
+
133
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
134
+ documentation from [`PretrainedConfig`] for more information.
135
+
136
+ Args:
137
+ vocab_size (`int`, *optional*, defaults to 151936):
138
+ Vocabulary size of the Qwen3ASR model. Defines the number of different tokens that can be represented by the
139
+ `inputs_ids` passed when calling [`Qwen3ASRModel`]
140
+ hidden_size (`int`, *optional*, defaults to 4096):
141
+ Dimension of the hidden representations.
142
+ intermediate_size (`int`, *optional*, defaults to 22016):
143
+ Dimension of the MLP representations.
144
+ num_hidden_layers (`int`, *optional*, defaults to 32):
145
+ Number of hidden layers in the Transformer encoder.
146
+ num_attention_heads (`int`, *optional*, defaults to 32):
147
+ Number of attention heads for each attention layer in the Transformer encoder.
148
+ num_key_value_heads (`int`, *optional*, defaults to 32):
149
+ This is the number of key_value heads that should be used to implement Grouped Query Attention. If
150
+ `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
151
+ `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
152
+ converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
153
+ by meanpooling all the original heads within that group. For more details, check out [this
154
+ paper](https://huggingface.co/papers/2305.13245). If it is not specified, will default to `32`.
155
+ head_dim (`int`, *optional*, defaults to 128):
156
+ The dimension of the head. If not specified, will default to `hidden_size // num_attention_heads`.
157
+ hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
158
+ The non-linear activation function (function or string) in the decoder.
159
+ max_position_embeddings (`int`, *optional*, defaults to 128000):
160
+ The maximum sequence length that this model might ever be used with.
161
+ initializer_range (`float`, *optional*, defaults to 0.02):
162
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
163
+ rms_norm_eps (`float`, *optional*, defaults to 1e-06):
164
+ The epsilon used by the rms normalization layers.
165
+ use_cache (`bool`, *optional*, defaults to `True`):
166
+ Whether or not the model should return the last key/values attentions (not used by all models). Only
167
+ relevant if `config.is_decoder=True`.
168
+ tie_word_embeddings (`bool`, *optional*, defaults to `False`):
169
+ Whether the model's input and output word embeddings should be tied.
170
+ rope_theta (`float`, *optional*, defaults to 5000000.0):
171
+ The base period of the RoPE embeddings.
172
+ rope_scaling (`Dict`, *optional*):
173
+ Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply new rope type
174
+ and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value
175
+ accordingly.
176
+ Expected contents:
177
+ `rope_type` (`str`):
178
+ The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope',
179
+ 'llama3'], with 'default' being the original RoPE implementation.
180
+ `factor` (`float`, *optional*):
181
+ Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In
182
+ most scaling types, a `factor` of x will enable the model to handle sequences of length x *
183
+ original maximum pre-trained length.
184
+ `original_max_position_embeddings` (`int`, *optional*):
185
+ Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during
186
+ pretraining.
187
+ `attention_factor` (`float`, *optional*):
188
+ Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention
189
+ computation. If unspecified, it defaults to value recommended by the implementation, using the
190
+ `factor` field to infer the suggested value.
191
+ `beta_fast` (`float`, *optional*):
192
+ Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear
193
+ ramp function. If unspecified, it defaults to 32.
194
+ `beta_slow` (`float`, *optional*):
195
+ Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear
196
+ ramp function. If unspecified, it defaults to 1.
197
+ `short_factor` (`list[float]`, *optional*):
198
+ Only used with 'longrope'. The scaling factor to be applied to short contexts (<
199
+ `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
200
+ size divided by the number of attention heads divided by 2
201
+ `long_factor` (`list[float]`, *optional*):
202
+ Only used with 'longrope'. The scaling factor to be applied to long contexts (<
203
+ `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
204
+ size divided by the number of attention heads divided by 2
205
+ `low_freq_factor` (`float`, *optional*):
206
+ Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE
207
+ `high_freq_factor` (`float`, *optional*):
208
+ Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE
209
+ attention_bias (`bool`, *optional*, defaults to `False`):
210
+ Whether to use a bias in the query, key, value and output projection layers during self-attention.
211
+ attention_dropout (`float`, *optional*, defaults to 0.0):
212
+ The dropout ratio for the attention probabilities.
213
+
214
+ ```python
215
+ >>> from transformers import Qwen3ASRTextModel, Qwen3ASRTextConfig
216
+
217
+ >>> # Initializing a Qwen3ASR style configuration
218
+ >>> configuration = Qwen3ASRTextConfig()
219
+
220
+ >>> # Initializing a model from the Qwen3-VL-7B style configuration
221
+ >>> model = Qwen3ASRTextModel(configuration)
222
+
223
+ >>> # Accessing the model configuration
224
+ >>> configuration = model.config
225
+ ```"""
226
+
227
+ model_type = "qwen3_asr_text"
228
+ base_config_key = "text_config"
229
+
230
+ def __init__(
231
+ self,
232
+ vocab_size=151936,
233
+ hidden_size=4096,
234
+ intermediate_size=22016,
235
+ num_hidden_layers=32,
236
+ num_attention_heads=32,
237
+ num_key_value_heads=32,
238
+ head_dim=128,
239
+ hidden_act="silu",
240
+ max_position_embeddings=128000,
241
+ initializer_range=0.02,
242
+ rms_norm_eps=1e-6,
243
+ use_cache=True,
244
+ tie_word_embeddings=False,
245
+ rope_theta=5000000.0,
246
+ rope_scaling=None,
247
+ attention_bias=False,
248
+ attention_dropout=0.0,
249
+ **kwargs,
250
+ ):
251
+ self.vocab_size = vocab_size
252
+ self.max_position_embeddings = max_position_embeddings
253
+ self.hidden_size = hidden_size
254
+ self.intermediate_size = intermediate_size
255
+ self.num_hidden_layers = num_hidden_layers
256
+ self.num_attention_heads = num_attention_heads
257
+
258
+ # for backward compatibility
259
+ if num_key_value_heads is None:
260
+ num_key_value_heads = num_attention_heads
261
+
262
+ self.num_key_value_heads = num_key_value_heads
263
+ self.head_dim = head_dim
264
+ self.hidden_act = hidden_act
265
+ self.initializer_range = initializer_range
266
+ self.rms_norm_eps = rms_norm_eps
267
+ self.use_cache = use_cache
268
+ self.rope_theta = rope_theta
269
+ self.rope_scaling = rope_scaling
270
+ self.attention_bias = attention_bias
271
+ self.attention_dropout = attention_dropout
272
+ # Validate the correctness of rotary position embeddings parameters
273
+ # BC: if there is a 'type' field, move it to 'rope_type'.
274
+ if self.rope_scaling is not None and "type" in self.rope_scaling:
275
+ self.rope_scaling["rope_type"] = self.rope_scaling["type"]
276
+
277
+ super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs)
278
+
279
+
280
+ class Qwen3ASRThinkerConfig(PretrainedConfig):
281
+ r"""
282
+ This is the configuration class to store the configuration of a [`Qwen3ASRThinker`]. It is used to instantiate a
283
+ Qwen3-ASR-Thinker model according to the specified arguments, defining the model architecture. Instantiating a
284
+ configuration with the defaults will yield a similar configuration to that of the thinker component of the Qwen3-Omni
285
+ architecture.
286
+
287
+ e.g. [Qwen/Qwen3-ASR-1.7B](https://huggingface.co/Qwen/Qwen3-ASR-1.7B)
288
+
289
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
290
+ documentation from [`PretrainedConfig`] for more information.
291
+
292
+ Args:
293
+ audio_config (`dict`, *optional*):
294
+ The config dictionary of the audio backbone.
295
+ text_config (`dict`, *optional*):
296
+ The config dictionary of the text backbone.
297
+ audio_token_id (`int`, *optional*, defaults to 151646):
298
+ The audio token id to encode the audio prompt.
299
+ audio_start_token_id (`int`, *optional*, defaults to 151647):
300
+ The audio start token id to encode the audio prompt.
301
+ user_token_id (`int`, *optional*, defaults to 872):
302
+ The user token id to encode the user token.
303
+ initializer_range (`float`, *optional*, defaults to 0.02):
304
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
305
+
306
+ Example:
307
+
308
+ ```python
309
+ >>> from transformers import Qwen3ASRThinkerModel, Qwen3ASRThinkerConfig
310
+
311
+ >>> # Initializing a default Qwen3ASRThinkerConfig
312
+ >>> configuration = Qwen3ASRThinkerConfig()
313
+
314
+ >>> # Initializing a model (with random weights) from the default configuration
315
+ >>> model = Qwen3ASRThinkerModel(configuration)
316
+
317
+ >>> # Accessing the model configuration
318
+ >>> configuration = model.config
319
+ ```"""
320
+
321
+ model_type = "qwen3_asr_thinker"
322
+
323
+ attribute_map = {}
324
+ sub_configs = {
325
+ "audio_config": Qwen3ASRAudioEncoderConfig,
326
+ "text_config": Qwen3ASRTextConfig,
327
+ }
328
+
329
+ def __init__(
330
+ self,
331
+ audio_config=None,
332
+ text_config=None,
333
+ audio_token_id=151646,
334
+ audio_start_token_id=151647,
335
+ user_token_id=872,
336
+ initializer_range=0.02,
337
+ **kwargs,
338
+ ):
339
+ super().__init__(**kwargs)
340
+ self.user_token_id = user_token_id
341
+ self.audio_start_token_id = audio_start_token_id
342
+ self.initializer_range = initializer_range
343
+
344
+ if isinstance(audio_config, dict):
345
+ audio_config = Qwen3ASRAudioEncoderConfig(**audio_config)
346
+ elif audio_config is None:
347
+ audio_config = Qwen3ASRAudioEncoderConfig()
348
+ self.audio_config = audio_config
349
+
350
+ if isinstance(text_config, dict):
351
+ text_config = Qwen3ASRTextConfig(**text_config)
352
+ elif text_config is None:
353
+ text_config = Qwen3ASRTextConfig()
354
+ self.text_config = text_config
355
+ self.audio_token_id = audio_token_id
356
+
357
+
358
+ class Qwen3ASRConfig(PretrainedConfig):
359
+ """
360
+ This is the configuration class to store the configuration of a [`Qwen3ASRForConditionalGeneration`]. It is used to instantiate a Qwen3ASR
361
+ model according to the specified sub-models configurations, defining the model architecture.
362
+
363
+ Instantiating a configuration with the defaults will yield a similar configuration to that of the
364
+ [Qwen/Qwen3-ASR-1.7B](https://huggingface.co/Qwen/Qwen3-ASR-1.7B) architecture.
365
+
366
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
367
+ documentation from [`PretrainedConfig`] for more information.
368
+
369
+ Args:
370
+ thinker_config (`dict`, *optional*): Configuration of the underlying thinker sub-model.
371
+ support_languages (`List[str]`, *optional*): The languages supported by the model.
372
+
373
+ Example:
374
+
375
+ ```python
376
+ >>> from transformers import (
377
+ ... Qwen3ASRThinkerConfig,
378
+ ... Qwen3ASRForConditionalGeneration,
379
+ ... Qwen3ASRConfig,
380
+ ... )
381
+
382
+ >>> # Initializing a Qwen3ASR style configuration
383
+ >>> configuration = Qwen3ASRConfig()
384
+
385
+ >>> # Initializing a model from the configuration
386
+ >>> model = Qwen3ASRForConditionalGeneration(configuration)
387
+
388
+ >>> # Accessing the model configuration
389
+ >>> configuration = model.config
390
+ ```"""
391
+
392
+ model_type = "qwen3_asr"
393
+ sub_configs = {
394
+ "thinker_config": Qwen3ASRThinkerConfig,
395
+ }
396
+
397
+ def __init__(
398
+ self,
399
+ thinker_config=None,
400
+ support_languages=None,
401
+ **kwargs,
402
+ ):
403
+ super().__init__(**kwargs)
404
+ if thinker_config is None:
405
+ thinker_config = {}
406
+
407
+ self.thinker_config = Qwen3ASRThinkerConfig(**thinker_config)
408
+ self.support_languages = support_languages
409
+
410
+ def get_text_config(self, decoder=False) -> "PretrainedConfig":
411
+ """
412
+ Returns the config that is meant to be used with text IO. On most models, it is the original config instance
413
+ itself. On specific composite models, it is under a set of valid names.
414
+
415
+ Args:
416
+ decoder (`Optional[bool]`, *optional*, defaults to `False`):
417
+ If set to `True`, then only search for decoder config names.
418
+ """
419
+ # Overridden for deeply nested config like Qwen2.5-Omni. We don't have any omni model
420
+ # except for Qwen yet. This has to be generalized if more deeply nested configs are
421
+ # added. NOTE: currently method used only by vLLM
422
+ return self.thinker_config.get_text_config()
423
+
424
+
425
+ __all__ = ["Qwen3ASRConfig", "Qwen3ASRThinkerConfig", "Qwen3ASRAudioEncoderConfig"]
qwen_asr/core/transformers_backend/modeling_qwen3_asr.py ADDED
@@ -0,0 +1,1361 @@
1
+ # coding=utf-8
2
+ # Copyright 2026 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ import math
16
+ from dataclasses import dataclass
17
+ from typing import Callable, Optional, Union
18
+
19
+ import numpy as np
20
+ import torch
21
+ from torch import nn
22
+ from torch.nn import functional as F
23
+
24
+ from transformers.activations import ACT2FN
25
+ from transformers.cache_utils import Cache, DynamicCache
26
+ from transformers.generation import GenerationMixin
27
+ from transformers.integrations import use_kernel_forward_from_hub
28
+ from transformers.masking_utils import create_causal_mask
29
+ from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
30
+ from transformers.modeling_layers import GradientCheckpointingLayer
31
+ from transformers.modeling_outputs import (
32
+ BaseModelOutput,
33
+ BaseModelOutputWithPast,
34
+ MoeCausalLMOutputWithPast,
35
+ )
36
+ from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
37
+ from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
38
+ from transformers.processing_utils import Unpack
39
+ from transformers.utils import auto_docstring, can_return_tuple
40
+ from transformers.utils.deprecation import deprecate_kwarg
41
+ from transformers.utils.generic import TransformersKwargs, check_model_inputs
42
+
43
+ from .configuration_qwen3_asr import (
44
+ Qwen3ASRAudioEncoderConfig,
45
+ Qwen3ASRConfig,
46
+ Qwen3ASRThinkerConfig,
47
+ )
48
+
49
+
50
+ @use_kernel_forward_from_hub("RMSNorm")
51
+ class Qwen3ASRTextRMSNorm(nn.Module):
52
+ def __init__(self, hidden_size, eps: float = 1e-6) -> None:
53
+ """
54
+ Qwen3ASRTextRMSNorm is equivalent to T5LayerNorm
55
+ """
56
+ super().__init__()
57
+ self.weight = nn.Parameter(torch.ones(hidden_size))
58
+ self.variance_epsilon = eps
59
+
60
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
61
+ input_dtype = hidden_states.dtype
62
+ hidden_states = hidden_states.to(torch.float32)
63
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
64
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
65
+ return self.weight * hidden_states.to(input_dtype)
66
+
67
+ def extra_repr(self):
68
+ return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
69
+
70
+
71
+ def rotate_half(x):
72
+ """Rotates half the hidden dims of the input."""
73
+ x1 = x[..., : x.shape[-1] // 2]
74
+ x2 = x[..., x.shape[-1] // 2 :]
75
+ return torch.cat((-x2, x1), dim=-1)
76
+
77
+
78
+ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
79
+ """
80
+ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
81
+ num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
82
+ """
83
+ batch, num_key_value_heads, slen, head_dim = hidden_states.shape
84
+ if n_rep == 1:
85
+ return hidden_states
86
+ hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
87
+ return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
88
+
89
+
90
+ def eager_attention_forward(
91
+ module: nn.Module,
92
+ query: torch.Tensor,
93
+ key: torch.Tensor,
94
+ value: torch.Tensor,
95
+ attention_mask: Optional[torch.Tensor],
96
+ scaling: float,
97
+ dropout: float = 0.0,
98
+ **kwargs: Unpack[TransformersKwargs],
99
+ ):
100
+ key_states = repeat_kv(key, module.num_key_value_groups)
101
+ value_states = repeat_kv(value, module.num_key_value_groups)
102
+
103
+ attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
104
+ if attention_mask is not None:
105
+ causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
106
+ attn_weights = attn_weights + causal_mask
107
+
108
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
109
+ attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
110
+ attn_output = torch.matmul(attn_weights, value_states)
111
+ attn_output = attn_output.transpose(1, 2).contiguous()
112
+
113
+ return attn_output, attn_weights
114
+
115
+
116
+ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
117
+ """Applies Rotary Position Embedding to the query and key tensors.
118
+
119
+ Args:
120
+ q (`torch.Tensor`): The query tensor.
121
+ k (`torch.Tensor`): The key tensor.
122
+ cos (`torch.Tensor`): The cosine part of the rotary embedding.
123
+ sin (`torch.Tensor`): The sine part of the rotary embedding.
124
+ position_ids (`torch.Tensor`, *optional*):
125
+ Deprecated and unused.
126
+ unsqueeze_dim (`int`, *optional*, defaults to 1):
127
+ The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
128
+ sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
129
+ that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
130
+ k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
131
+ cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
132
+ the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
133
+ Returns:
134
+ `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
135
+ """
136
+ cos = cos.unsqueeze(unsqueeze_dim)
137
+ sin = sin.unsqueeze(unsqueeze_dim)
138
+ q_embed = (q * cos) + (rotate_half(q) * sin)
139
+ k_embed = (k * cos) + (rotate_half(k) * sin)
140
+ return q_embed, k_embed
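A tiny numerical check of the two rotary helpers: with cos = 0 and sin = 1 (a 90° rotation), `apply_rotary_pos_emb` must reduce to `rotate_half`:

```python
import torch

from qwen_asr.core.transformers_backend.modeling_qwen3_asr import (
    apply_rotary_pos_emb,
    rotate_half,
)

# One batch, one head, one position, head_dim = 4 (two rotary pairs).
q = torch.tensor([[[[1.0, 2.0, 3.0, 4.0]]]])  # (batch, heads, seq, head_dim)
k = q.clone()
cos = torch.zeros(1, 1, 4)                    # (batch, seq, head_dim)
sin = torch.ones(1, 1, 4)

q_rot, k_rot = apply_rotary_pos_emb(q, k, cos, sin)

# q*cos + rotate_half(q)*sin == rotate_half(q) == [-3, -4, 1, 2]
assert torch.allclose(q_rot, rotate_half(q))
assert torch.allclose(k_rot, rotate_half(k))
```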
141
+
142
+
143
+ class Qwen3ASRTextAttention(nn.Module):
144
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
145
+
146
+ def __init__(self, config: Qwen3ASRConfig, layer_idx: int):
147
+ super().__init__()
148
+ self.config = config
149
+ self.layer_idx = layer_idx
150
+ self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
151
+ self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
152
+ self.scaling = self.head_dim**-0.5
153
+ self.attention_dropout = config.attention_dropout
154
+ self.is_causal = True
155
+
156
+ self.q_proj = nn.Linear(
157
+ config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias
158
+ )
159
+ self.k_proj = nn.Linear(
160
+ config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
161
+ )
162
+ self.v_proj = nn.Linear(
163
+ config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
164
+ )
165
+ self.o_proj = nn.Linear(
166
+ config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
167
+ )
168
+ self.q_norm = Qwen3ASRTextRMSNorm(
169
+ self.head_dim, eps=config.rms_norm_eps
170
+ ) # unlike olmo, only on the head dim!
171
+ self.k_norm = Qwen3ASRTextRMSNorm(
172
+ self.head_dim, eps=config.rms_norm_eps
173
+ ) # thus post q_norm does not need reshape
174
+
175
+ @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
176
+ def forward(
177
+ self,
178
+ hidden_states: torch.Tensor,
179
+ position_embeddings: tuple[torch.Tensor, torch.Tensor],
180
+ attention_mask: Optional[torch.Tensor],
181
+ past_key_values: Optional[Cache] = None,
182
+ cache_position: Optional[torch.LongTensor] = None,
183
+ **kwargs: Unpack[FlashAttentionKwargs],
184
+ ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
185
+ input_shape = hidden_states.shape[:-1]
186
+ hidden_shape = (*input_shape, -1, self.head_dim)
187
+
188
+ query_states = self.q_norm(self.q_proj(hidden_states).view(hidden_shape)).transpose(1, 2)
189
+ key_states = self.k_norm(self.k_proj(hidden_states).view(hidden_shape)).transpose(1, 2)
190
+ value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
191
+
192
+ cos, sin = position_embeddings
193
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
194
+
195
+ if past_key_values is not None:
196
+ # sin and cos are specific to RoPE models; cache_position needed for the static cache
197
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
198
+ key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)
199
+
200
+ attention_interface: Callable = eager_attention_forward
201
+ if self.config._attn_implementation != "eager":
202
+ attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
203
+
204
+ attn_output, attn_weights = attention_interface(
205
+ self,
206
+ query_states,
207
+ key_states,
208
+ value_states,
209
+ attention_mask,
210
+ dropout=0.0 if not self.training else self.attention_dropout,
211
+ scaling=self.scaling,
212
+ **kwargs,
213
+ )
214
+
215
+ attn_output = attn_output.reshape(*input_shape, -1).contiguous()
216
+ attn_output = self.o_proj(attn_output)
217
+ return attn_output, attn_weights
218
+
219
+
220
+ class Qwen3ASRTextMLP(nn.Module):
221
+ def __init__(self, config):
222
+ super().__init__()
223
+ self.config = config
224
+ self.hidden_size = config.hidden_size
225
+ self.intermediate_size = config.intermediate_size
226
+ self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
227
+ self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
228
+ self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
229
+ self.act_fn = ACT2FN[config.hidden_act]
230
+
231
+ def forward(self, x):
232
+ down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
233
+ return down_proj
234
+
235
+
236
+ class Qwen3ASRThinkerTextDecoderLayer(GradientCheckpointingLayer):
237
+ def __init__(self, config: Qwen3ASRConfig, layer_idx: int):
238
+ super().__init__()
239
+ self.hidden_size = config.hidden_size
240
+
241
+ self.self_attn = Qwen3ASRTextAttention(config=config, layer_idx=layer_idx)
242
+
243
+ self.mlp = Qwen3ASRTextMLP(config)
244
+ self.input_layernorm = Qwen3ASRTextRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
245
+ self.post_attention_layernorm = Qwen3ASRTextRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
246
+
247
+ @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
248
+ def forward(
249
+ self,
250
+ hidden_states: torch.Tensor,
251
+ position_embeddings: tuple[torch.Tensor, torch.Tensor],
252
+ attention_mask: Optional[torch.Tensor] = None,
253
+ position_ids: Optional[torch.LongTensor] = None,
254
+ past_key_values: Optional[Cache] = None,
255
+ use_cache: Optional[bool] = False,
256
+ cache_position: Optional[torch.LongTensor] = None,
257
+ **kwargs: Unpack[TransformersKwargs],
258
+ ) -> torch.Tensor:
259
+ residual = hidden_states
260
+ hidden_states = self.input_layernorm(hidden_states)
261
+ # Self Attention
262
+ hidden_states, _ = self.self_attn(
263
+ hidden_states=hidden_states,
264
+ attention_mask=attention_mask,
265
+ position_ids=position_ids,
266
+ past_key_values=past_key_values,
267
+ use_cache=use_cache,
268
+ cache_position=cache_position,
269
+ position_embeddings=position_embeddings,
270
+ **kwargs,
271
+ )
272
+ hidden_states = residual + hidden_states
273
+
274
+ # Fully Connected
275
+ residual = hidden_states
276
+ hidden_states = self.post_attention_layernorm(hidden_states)
277
+ hidden_states = self.mlp(hidden_states)
278
+ hidden_states = residual + hidden_states
279
+ return hidden_states
280
+
281
+
282
+ @auto_docstring
283
+ class Qwen3ASRPreTrainedModel(PreTrainedModel):
284
+ config: Qwen3ASRConfig
285
+ base_model_prefix = "model"
286
+ supports_gradient_checkpointing = True
287
+ _skip_keys_device_placement = "past_key_values"
288
+ _supports_flash_attn = True
289
+ _supports_sdpa = True
290
+
291
+ _can_compile_fullgraph = True
292
+ _supports_attention_backend = True
293
+ _can_record_outputs = {
294
+ "attentions": Qwen3ASRTextAttention,
295
+ }
296
+
297
+
298
+ @dataclass
299
+ class Qwen3ASRThinkerCausalLMOutputWithPast(MoeCausalLMOutputWithPast):
300
+ r"""
301
+ Args:
302
+ rope_deltas (`torch.LongTensor` of shape `(batch_size, )`, *optional*):
303
+ The rope index difference between sequence length and multimodal rope.
304
+ """
305
+
306
+ rope_deltas: Optional[torch.LongTensor] = None
307
+
308
+
309
+ def _get_feat_extract_output_lengths(input_lengths):
310
+ """
311
+ Computes the output length of the convolutional layers and the output length of the audio encoder
312
+ """
313
+
314
+ input_lengths_leave = input_lengths % 100
315
+ feat_lengths = (input_lengths_leave - 1) // 2 + 1
316
+ output_lengths = ((feat_lengths - 1) // 2 + 1 - 1) // 2 + 1 + (input_lengths // 100) * 13
317
+ return output_lengths
318
+
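+ # Worked example (illustrative only): every full block of 100 mel frames contributes 13 encoder
+ # tokens, and the remainder passes through three stride-2 reductions, e.g.
+ #     [_get_feat_extract_output_lengths(n) for n in (100, 160, 300)]   # -> [13, 21, 39]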
319
+
320
+ class Qwen3ASRPreTrainedModelForConditionalGeneration(Qwen3ASRPreTrainedModel):
321
+ def _prepare_4d_causal_attention_mask_with_cache_position(
322
+ self,
323
+ attention_mask: torch.Tensor,
324
+ sequence_length: int,
325
+ target_length: int,
326
+ dtype: torch.dtype,
327
+ device: torch.device,
328
+ min_dtype: float,
329
+ cache_position: torch.Tensor,
330
+ batch_size: int,
331
+ ):
332
+ """
333
+ Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
334
+ `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
335
+
336
+ Args:
337
+ attention_mask (`torch.Tensor`):
338
+ A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`.
339
+ sequence_length (`int`):
340
+ The sequence length being processed.
341
+ target_length (`int`):
342
+ The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet.
343
+ dtype (`torch.dtype`):
344
+ The dtype to use for the 4D attention mask.
345
+ device (`torch.device`):
346
+ The device to place the 4D attention mask on.
347
+ min_dtype (`float`):
348
+ The minimum value representable with the dtype `dtype`.
349
+ cache_position (`torch.Tensor`):
350
+ Indices depicting the position of the input sequence tokens in the sequence.
351
+ batch_size (`torch.Tensor`):
352
+ Batch size.
353
+ """
354
+ if attention_mask is not None and attention_mask.dim() == 4:
355
+ # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
356
+ causal_mask = attention_mask
357
+ else:
358
+ causal_mask = torch.full(
359
+ (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
360
+ )
361
+ if sequence_length != 1:
362
+ causal_mask = torch.triu(causal_mask, diagonal=1)
363
+ causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
364
+ causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
365
+ if attention_mask is not None:
366
+ causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
367
+ mask_length = attention_mask.shape[-1]
368
+ padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
369
+ padding_mask = padding_mask == 0
370
+ causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
371
+ padding_mask, min_dtype
372
+ )
373
+
374
+ return causal_mask
375
+
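+ # Illustrative example (not part of the original file): during prefill with
+ # sequence_length == target_length == 3 and no padding, each batch item gets
+ #     [[0, m, m],
+ #      [0, 0, m],
+ #      [0, 0, 0]]      with m == min_dtype,
+ # i.e. position i may attend only to positions <= i; padded key positions are additionally
+ # pushed to m by the 2D `attention_mask` merge in the branch above.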
376
+
377
+ def get_chunked_index(
378
+ self, token_indices: torch.Tensor, tokens_per_chunk: int, remove_index: int
379
+ ) -> list[tuple[int, int]]:
380
+ """
381
+ Splits token index list into chunks based on token value ranges.
382
+
383
+ Given a list of token indices, returns a list of (start, end) index tuples representing
384
+ slices of the list where the token values fall within successive ranges of `tokens_per_chunk`.
+
+ For example, if `tokens_per_chunk` is 1000, the function will create chunks such that:
+ - the first chunk contains token values < 1000,
+ - the second chunk contains values >= 1000 and < 2000, and so on.
+
+ Parameters:
+ token_indices (`torch.Tensor` of shape `(seq_len, )`): A monotonically increasing list of
+ token index values.
+ tokens_per_chunk (`int`): Number of tokens per chunk (used as the chunk size threshold).
+ remove_index (`int`): An index id to subtract from `token_indices` before chunking.
395
+
396
+ Returns:
397
+ `list[tuple[int, int]]`: A list of tuples, each representing the start (inclusive)
398
+ and end (exclusive) indices of a chunk in `token_indices`.
399
+ """
400
+
401
+ def _iter():
402
+ i, start_idx = 0, 0 # skip bos token
403
+ current_chunk = 1
404
+ while i < len(token_indices): # skip eos token
405
+ if token_indices[i] - remove_index >= current_chunk * tokens_per_chunk:
406
+ yield (start_idx, i)
407
+ start_idx = i
408
+ current_chunk += 1
409
+ i += 1
410
+ yield (start_idx, len(token_indices))
411
+
412
+ return list(_iter())
413
+
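+ # Worked example (illustrative only): with token_indices == [3, 4, 1003, 1004, 2500],
+ # tokens_per_chunk == 1000 and remove_index == 3, the shifted values are
+ # [0, 1, 1000, 1001, 2497], so the method yields [(0, 2), (2, 4), (4, 5)].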
414
+ def get_rope_index(
415
+ self,
416
+ attention_mask: Optional[torch.Tensor] = None,
417
+ ) -> tuple[torch.Tensor, torch.Tensor]:
418
+ """
419
+ Calculate the rope index in LLM.
420
+
421
+ Explanation:
+ Each embedding sequence is treated as plain text here, so the rope index is simply the
+ cumulative position of every non-padded token, broadcast over the three mrope dimensions.
+
+ Args:
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+
+ - 1 for tokens that are **not masked**,
+ - 0 for tokens that are **masked**.
435
+
436
+ Returns:
437
+ position_ids (`torch.LongTensor` of shape `(3, batch_size, sequence_length)`)
438
+ mrope_position_deltas (`torch.Tensor` of shape `(batch_size)`)
439
+ """
440
+ mrope_position_deltas = []
441
+
442
+ position_ids = attention_mask.float().cumsum(-1) - 1
443
+ position_ids.masked_fill_(attention_mask == 0, 1)
444
+ position_ids = position_ids.unsqueeze(0).expand(3, -1, -1).to(attention_mask.device)
445
+ max_position_ids = position_ids.max(0, keepdim=False)[0].max(-1, keepdim=True)[0]
446
+ mrope_position_deltas = max_position_ids + 1 - torch.sum(attention_mask, dim=-1, keepdim=True)
447
+
448
+ return position_ids, mrope_position_deltas
449
+
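+ # Illustrative example (not part of the original file): for a left-padded batch of one sequence,
+ #     attention_mask = torch.tensor([[0, 0, 1, 1, 1]])
+ #     position_ids, deltas = self.get_rope_index(attention_mask)
+ # position_ids[0] is tensor([[1., 1., 0., 1., 2.]]) and is shared by all three mrope dims,
+ # while deltas is tensor([[0.]]) because max position + 1 equals the number of real tokens.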
450
+
451
+ class Qwen3ASRAudioAttention(nn.Module):
452
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
453
+
454
+ def __init__(self, config):
455
+ super().__init__()
456
+ self.embed_dim = config.d_model
457
+ self.num_heads = config.encoder_attention_heads
458
+ self.dropout = config.attention_dropout
459
+ self.head_dim = self.embed_dim // self.num_heads
460
+ self.num_key_value_groups = 1 # needed for eager attention
461
+ self.config = config
462
+
463
+ if (self.head_dim * self.num_heads) != self.embed_dim:
464
+ raise ValueError(
465
+ f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}"
466
+ f" and `num_heads`: {self.num_heads})."
467
+ )
468
+ self.scaling = self.head_dim**-0.5
469
+ self.attention_dropout = 0.0
470
+ self.is_decoder = False
471
+ self.is_causal = False
472
+ self.k_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=True)
473
+ self.v_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=True)
474
+ self.q_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=True)
475
+ self.out_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=True)
476
+
477
+ def forward(
478
+ self,
479
+ hidden_states: torch.Tensor,
480
+ cu_seqlens: Optional[torch.Tensor] = None,
481
+ attention_mask: Optional[torch.Tensor] = None,
482
+ **kwargs,
483
+ ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
484
+ """Input shape: Batch x Time x Channel"""
485
+
486
+ seq_length, _ = hidden_states.size()
487
+
488
+ query_states = self.q_proj(hidden_states).reshape(seq_length, self.num_heads, -1)
489
+ key_states = self.k_proj(hidden_states).reshape(seq_length, self.num_heads, -1)
490
+ value_states = self.v_proj(hidden_states).reshape(seq_length, self.num_heads, -1)
491
+
492
+ query_states = query_states.transpose(0, 1).unsqueeze(0)
493
+ key_states = key_states.transpose(0, 1).unsqueeze(0)
494
+ value_states = value_states.transpose(0, 1).unsqueeze(0)
495
+ max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max()
496
+
497
+ attention_interface: Callable = eager_attention_forward
498
+ if self.config._attn_implementation != "eager":
499
+ attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
500
+
501
+ attn_output, _ = attention_interface(
502
+ self,
503
+ query_states,
504
+ key_states,
505
+ value_states,
506
+ attention_mask=attention_mask,
507
+ dropout=0.0 if not self.training else self.attention_dropout,
508
+ scaling=self.scaling,
509
+ cu_seq_lens_q=cu_seqlens, # pass cu seq lens for FA2
510
+ cu_seq_lens_k=cu_seqlens,
511
+ max_length_q=max_seqlen,
512
+ max_length_k=max_seqlen,
513
+ is_causal=False,
514
+ **kwargs,
515
+ )
516
+
517
+ attn_output = attn_output.reshape(seq_length, -1).contiguous()
518
+ attn_output = self.out_proj(attn_output)
519
+
520
+ return attn_output
521
+
522
+
523
+ class Qwen3ASRAudioEncoderLayer(GradientCheckpointingLayer):
524
+ def __init__(self, config: Qwen3ASRAudioEncoderConfig):
525
+ super().__init__()
526
+ self.embed_dim = config.d_model
527
+ self.self_attn = Qwen3ASRAudioAttention(config)
528
+ self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
529
+ self.dropout = config.dropout
530
+ self.activation_fn = ACT2FN[config.activation_function]
531
+ self.activation_dropout = config.activation_dropout
532
+ self.fc1 = nn.Linear(self.embed_dim, config.encoder_ffn_dim)
533
+ self.fc2 = nn.Linear(config.encoder_ffn_dim, self.embed_dim)
534
+ self.final_layer_norm = nn.LayerNorm(self.embed_dim)
535
+
536
+ def forward(
537
+ self,
538
+ hidden_states: torch.Tensor,
539
+ cu_seqlens: torch.Tensor,
540
+ attention_mask: Optional[torch.Tensor] = None,
541
+ **kwargs,
542
+ ) -> torch.Tensor:
543
+ """
544
+ Args:
545
+ hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
546
+ attention_mask (`torch.FloatTensor`): attention mask of size
547
+ `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
548
+ cu_seqlens (`torch.Tensor`): cumulative sequence lengths delimiting the attention blocks
+ within the packed `hidden_states`.
553
+ """
554
+ residual = hidden_states
555
+ hidden_states = self.self_attn_layer_norm(hidden_states)
556
+ hidden_states = self.self_attn(
557
+ hidden_states=hidden_states,
558
+ cu_seqlens=cu_seqlens,
559
+ attention_mask=attention_mask,
560
+ **kwargs,
561
+ )
562
+ hidden_states = residual + hidden_states
563
+ residual = hidden_states
564
+ hidden_states = self.final_layer_norm(hidden_states)
565
+ hidden_states = self.fc1(hidden_states)
566
+ hidden_states = self.activation_fn(hidden_states)
567
+ hidden_states = self.fc2(hidden_states)
568
+ hidden_states = residual + hidden_states
569
+
570
+ if hidden_states.dtype == torch.float16:
571
+ clamp_value = torch.finfo(hidden_states.dtype).max - 1000
572
+ hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)
573
+
574
+ outputs = (hidden_states,)
575
+
576
+ return outputs
577
+
578
+
579
+ class SinusoidsPositionEmbedding(nn.Module):
580
+ def __init__(self, length, channels, max_timescale=10000):
581
+ super().__init__()
582
+ if channels % 2 != 0:
583
+ raise ValueError("SinusoidsPositionEmbedding needs even channels input")
584
+ log_timescale_increment = np.log(max_timescale) / (channels // 2 - 1)
585
+ inv_timescales = torch.exp(-log_timescale_increment * torch.arange(channels // 2).float())
586
+ scaled_time = torch.arange(length)[:, np.newaxis] * inv_timescales[np.newaxis, :]
587
+ self.register_buffer(
588
+ "positional_embedding",
589
+ torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=1),
590
+ persistent=False,
591
+ )
592
+
593
+ def forward(self, seqlen: int):
594
+ return self.positional_embedding[:seqlen, :]
595
+
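+ # Illustrative usage (not part of the original file): the sinusoid table is precomputed once and
+ # then sliced, e.g.
+ #     pe = SinusoidsPositionEmbedding(length=1500, channels=512)
+ #     pe.positional_embedding.shape   # torch.Size([1500, 512])
+ #     pe(100).shape                   # torch.Size([100, 512]), the first 100 positions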
596
+
597
+ @auto_docstring(
598
+ custom_intro="""
599
+ Transformer encoder consisting of *config.encoder_layers* self attention layers. Each layer is a
600
+ [`Qwen3ASRAudioEncoderLayer`].
601
+ """
602
+ )
603
+ class Qwen3ASRAudioEncoder(Qwen3ASRPreTrainedModel):
604
+ config: Qwen3ASRAudioEncoderConfig
605
+ main_input_name = "input_features"
606
+ _no_split_modules = ["Qwen3ASRAudioEncoderLayer"]
607
+ _supports_sdpa = True
608
+
609
+ def __init__(self, config: Qwen3ASRAudioEncoderConfig):
610
+ super().__init__(config)
611
+ self.dropout = config.dropout
612
+
613
+ embed_dim = config.d_model
614
+ self.num_mel_bins = config.num_mel_bins
615
+ self.max_source_positions = config.max_source_positions
616
+ self.embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0
617
+ self.n_window = config.n_window
618
+ self.positional_embedding = SinusoidsPositionEmbedding(self.max_source_positions, embed_dim)
619
+ self.layers = nn.ModuleList([Qwen3ASRAudioEncoderLayer(config) for _ in range(config.encoder_layers)])
620
+ self.ln_post = nn.LayerNorm(config.d_model)
621
+ self.gradient_checkpointing = False
622
+ self.conv2d1 = nn.Conv2d(1, config.downsample_hidden_size, 3, 2, padding=1)
623
+ self.conv2d2 = nn.Conv2d(config.downsample_hidden_size, config.downsample_hidden_size, 3, 2, padding=1)
624
+ self.conv2d3 = nn.Conv2d(config.downsample_hidden_size, config.downsample_hidden_size, 3, 2, padding=1)
625
+ self.conv_out = nn.Linear(
626
+ config.downsample_hidden_size * ((((config.num_mel_bins + 1) // 2 + 1) // 2 + 1) // 2),
627
+ config.d_model,
628
+ bias=False,
629
+ )
630
+ self.proj1 = nn.Linear(config.d_model, config.d_model)
631
+ self.act = ACT2FN[config.activation_function]
632
+ self.proj2 = nn.Linear(config.d_model, config.output_dim)
633
+ self.n_window_infer = self.config.n_window_infer
634
+ self.conv_chunksize = self.config.conv_chunksize
635
+ # Initialize weights and apply final processing
636
+ self.post_init()
637
+
638
+ def _freeze_parameters(self):
639
+ for param in self.parameters():
640
+ param.requires_grad = False
641
+ self._requires_grad = False
642
+
643
+ def get_input_embeddings(self) -> nn.Module:
+ # The encoder defines no `conv1` attribute; its first downsampling convolution is `conv2d1`.
+ return self.conv2d1
+
+ def set_input_embeddings(self, value: nn.Module):
+ self.conv2d1 = value
648
+
649
+ def _prepare_attention_mask(self, inputs_tensor: torch.Tensor, cu_seqlens: torch.Tensor) -> torch.Tensor:
650
+ # Flash Attention 2 doesn't need a 4D mask and relies on `cu_seqlens`/`max_seqlen` instead.
+ # NOTE: the created attention mask only approximates the ragged FA2 attention by
+ # allowing bidirectional attention within `cu_seqlens` blocks and not attending between
+ # blocks, so it will not be a 100% match for FA2's `varlen` path.
654
+ if self.config._attn_implementation == "flash_attention_2":
655
+ return None
656
+
657
+ seq_length = inputs_tensor.shape[0]
658
+ attention_mask = torch.full(
659
+ [1, 1, seq_length, seq_length],
660
+ torch.finfo(inputs_tensor.dtype).min,
661
+ device=inputs_tensor.device,
662
+ dtype=inputs_tensor.dtype,
663
+ )
664
+ for i in range(1, len(cu_seqlens)):
665
+ attention_mask[..., cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]] = 0
666
+ return attention_mask
667
+
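+ # Illustrative example (not part of the original file): for a packed sequence of length 5 with
+ # cu_seqlens == [0, 3, 5], the mask is block-diagonal -- positions 0-2 attend only to 0-2 and
+ # positions 3-4 only to 3-4, while all cross-block entries stay at the dtype minimum. Under
+ # flash_attention_2 the method returns None and the block structure is conveyed via cu_seqlens.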
668
+ @auto_docstring
669
+ def forward(
670
+ self,
671
+ input_features,
672
+ feature_lens=None,
673
+ aftercnn_lens=None,
674
+ ):
675
+ r"""
676
+ feature_lens (`torch.LongTensor` of shape `(batch_size,)`):
677
+ mel length
678
+ aftercnn_lens (`torch.LongTensor` of shape `(batch_size,)`):
679
+ mel length after cnn
680
+ """
681
+ aftercnn_lens = _get_feat_extract_output_lengths(feature_lens)
682
+ chunk_num = torch.ceil(feature_lens / (self.n_window * 2)).long()
683
+
684
+ chunk_lengths = torch.tensor(
685
+ [self.n_window * 2] * chunk_num.sum(),
686
+ dtype=torch.long,
687
+ device=feature_lens.device,
688
+ )
689
+ tail_chunk_index = F.pad(chunk_num, (1, 0), value=-1).cumsum(0)[1:]
690
+ chunk_lengths[tail_chunk_index] = feature_lens % (self.n_window * 2)
691
+ chunk_lengths[chunk_lengths == 0] = self.n_window * 2
692
+
693
+ chunk_list = input_features.T.split(chunk_lengths.tolist(), dim=0)
694
+ padded_feature = nn.utils.rnn.pad_sequence(chunk_list, batch_first=True).transpose(1, 2)
695
+ feature_lens_after_cnn = _get_feat_extract_output_lengths(chunk_lengths)
696
+ padded_mask_after_cnn = nn.utils.rnn.pad_sequence(
697
+ [torch.ones(length, dtype=torch.bool, device=padded_feature.device) for length in feature_lens_after_cnn],
698
+ batch_first=True,
699
+ )
700
+ padded_feature = padded_feature.unsqueeze(1)
701
+ # Split to chunk to avoid OOM during convolution
702
+ padded_embeds = []
703
+ for chunk in padded_feature.split(self.conv_chunksize, dim=0):
704
+ padded_embed = F.gelu(self.conv2d1(chunk))
705
+ padded_embed = F.gelu(self.conv2d2(padded_embed))
706
+ padded_embed = F.gelu(self.conv2d3(padded_embed))
707
+ padded_embeds.append(padded_embed)
708
+ padded_embed = torch.cat(padded_embeds, dim=0)
709
+ b, c, f, t = padded_embed.size()
710
+ padded_embed = self.conv_out(padded_embed.permute(0, 3, 1, 2).contiguous().view(b, t, c * f))
711
+
712
+ positional_embedding = (
713
+ self.positional_embedding.positional_embedding[: padded_embed.shape[1], :]
714
+ .unsqueeze(0)
715
+ .to(padded_embed.dtype)
716
+ )
717
+ padded_embed = padded_embed + positional_embedding
718
+ hidden_states = padded_embed[padded_mask_after_cnn]
719
+ cu_chunk_lens = [0]
720
+ window_aftercnn = padded_mask_after_cnn.shape[-1] * (self.n_window_infer // (self.n_window * 2))
721
+ for cnn_len in aftercnn_lens:
722
+ cu_chunk_lens += [window_aftercnn] * (cnn_len // window_aftercnn)
723
+ remainder = cnn_len % window_aftercnn
724
+ if remainder != 0:
725
+ cu_chunk_lens += [remainder]
726
+ cu_seqlens = torch.tensor(cu_chunk_lens, device=aftercnn_lens.device).cumsum(-1, dtype=torch.int32)
727
+
728
+ for encoder_layer in self.layers:
729
+ layer_outputs = encoder_layer(
730
+ hidden_states,
731
+ cu_seqlens,
732
+ )
733
+
734
+ hidden_states = layer_outputs[0]
735
+
736
+ hidden_states = self.ln_post(hidden_states)
737
+ hidden_states = self.proj1(hidden_states)
738
+ hidden_states = self.act(hidden_states)
739
+ hidden_states = self.proj2(hidden_states)
740
+ return BaseModelOutput(last_hidden_state=hidden_states)
741
+
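+ # Illustrative sketch of the chunking above (the window size is an assumption, not a model
+ # constant): with n_window == 50, a single 230-frame mel input is split into windows of
+ # n_window * 2 == 100 frames, giving chunk_lengths == [100, 100, 30]; each padded window goes
+ # through the Conv2d stack separately, and cu_seqlens then groups roughly n_window_infer input
+ # frames' worth of tokens into the block-diagonal attention segments used by the encoder layers.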
742
+ def padded_and_mask_function(self, tensor_list, tensor_len, padding_value=0, padding_side="right"):
743
+ """
744
+ Pads a sequence of tensors to their maximum length on indicated `padding_side`.
745
+ Then prepares a mask so that pad tokens are not attended to.
746
+ """
747
+ max_len = tensor_len.max()
748
+ dim = tensor_list[0].shape[0]
749
+ padded_tensor = torch.full(
750
+ size=(len(tensor_list), dim, max_len),
751
+ fill_value=padding_value,
752
+ dtype=self.dtype,
753
+ device=tensor_list[0].device,
754
+ )
755
+
756
+ batch_mask = torch.zeros(
757
+ (len(tensor_len), max_len),
758
+ dtype=torch.long,
759
+ device=padded_tensor.device,
760
+ )
761
+ for i, length in enumerate(tensor_len):
762
+ batch_mask[i, :length] = 1
763
+ padded_tensor[i, :, :length] = tensor_list[i]
764
+
765
+ feature_lens_after_cnn = (tensor_len - 1) // 2 + 1
766
+ max_len_after_cnn = feature_lens_after_cnn.max()
767
+ batch_mask_after_cnn = torch.zeros(
768
+ (len(tensor_len), max_len_after_cnn),
769
+ dtype=torch.long,
770
+ device=padded_tensor.device,
771
+ )
772
+ for i, length in enumerate(feature_lens_after_cnn):
773
+ batch_mask_after_cnn[i, :length] = 1
774
+ return (
775
+ padded_tensor,
776
+ batch_mask.unsqueeze(1),
777
+ batch_mask_after_cnn.bool(),
778
+ )
779
+
780
+
781
+ class Qwen3ASRThinkerTextRotaryEmbedding(nn.Module):
782
+ inv_freq: torch.Tensor # fix linting for `register_buffer`
783
+
784
+ def __init__(self, config: Qwen3ASRConfig, device=None):
785
+ super().__init__()
786
+ if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
787
+ self.rope_type = config.rope_scaling.get("rope_type", "default")
788
+ else:
789
+ self.rope_type = "default"
790
+ self.max_seq_len_cached = config.max_position_embeddings
791
+ self.original_max_seq_len = config.max_position_embeddings
792
+
793
+ self.config = config
794
+ self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
795
+
796
+ inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
797
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
798
+ self.original_inv_freq = self.inv_freq
799
+
800
+ self.mrope_section = config.rope_scaling.get("mrope_section", [24, 20, 20])
801
+
802
+ def apply_interleaved_mrope(self, freqs, mrope_section):
803
+ """Apply interleaved MRoPE to 3D rotary embeddings.
804
+ Reorganizes frequency layout from chunked [TTT...HHH...WWW] to
805
+ interleaved [THTHWHTHW...TT], preserving frequency continuity.
806
+ args:
807
+ x: (3, bs, seq_len, head_dim // 2)
808
+ mrope_section: (3,)
809
+ returns:
810
+ x_t: (bs, seq_len, head_dim // 2)
811
+ """
812
+ freqs_t = freqs[0] # just overwrite the first dimension T
813
+ for dim, offset in enumerate((1, 2), start=1): # H, W
814
+ length = mrope_section[dim] * 3
815
+ idx = slice(offset, length, 3)
816
+ freqs_t[..., idx] = freqs[dim, ..., idx]
817
+ return freqs_t
818
+
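+ # Illustrative note (not part of the original file), assuming head_dim // 2 == 64, i.e. the sum
+ # of mrope_section == [24, 20, 20]: the interleaving takes
+ #     H frequencies at indices 1, 4, ..., 58   (slice(1, 60, 3), 20 slots),
+ #     W frequencies at indices 2, 5, ..., 59   (slice(2, 60, 3), 20 slots),
+ # and every remaining slot (0, 3, ..., 57 and 60..63) keeps the temporal (T) frequencies,
+ # yielding the T,H,W,T,H,W,...,T,T layout described in the docstring.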
819
+ @torch.no_grad()
820
+ @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)
821
+ def forward(self, x, position_ids):
822
+ # In contrast to other models, Qwen3ASRThinker has different position ids for the grids
823
+ # So we expand the inv_freq to shape (3, ...)
824
+ if position_ids.ndim == 2:
825
+ position_ids = position_ids[None, ...].expand(3, position_ids.shape[0], -1)
826
+ inv_freq_expanded = self.inv_freq[None, None, :, None].float().expand(3, position_ids.shape[1], -1, 1)
827
+ position_ids_expanded = position_ids[:, :, None, :].float() # shape (3, bs, 1, positions)
828
+
829
+ device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
830
+ with torch.autocast(device_type=device_type, enabled=False): # Force float32
831
+ freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(2, 3)
832
+ freqs = self.apply_interleaved_mrope(freqs, self.mrope_section)
833
+ emb = torch.cat((freqs, freqs), dim=-1)
834
+ cos = emb.cos() * self.attention_scaling
835
+ sin = emb.sin() * self.attention_scaling
836
+
837
+ return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
838
+
839
+
840
+ class Qwen3ASRThinkerTextMLP(nn.Module):
841
+ def __init__(self, config, intermediate_size=None):
842
+ super().__init__()
843
+ self.config = config
844
+ self.hidden_size = config.hidden_size
845
+ self.intermediate_size = intermediate_size if intermediate_size is not None else config.intermediate_size
846
+ self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
847
+ self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
848
+ self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
849
+ self.act_fn = ACT2FN[config.hidden_act]
850
+
851
+ def forward(self, x):
852
+ down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
853
+ return down_proj
854
+
855
+
856
+ @use_kernel_forward_from_hub("RMSNorm")
857
+ class Qwen3ASRThinkerTextRMSNorm(nn.Module):
858
+ def __init__(self, hidden_size, eps=1e-6):
859
+ """
860
+ Qwen3ASRThinkerTextRMSNorm is equivalent to T5LayerNorm
861
+ """
862
+ super().__init__()
863
+ self.weight = nn.Parameter(torch.ones(hidden_size))
864
+ self.variance_epsilon = eps
865
+
866
+ def forward(self, hidden_states):
867
+ input_dtype = hidden_states.dtype
868
+ hidden_states = hidden_states.to(torch.float32)
869
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
870
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
871
+ return self.weight * hidden_states.to(input_dtype)
872
+
873
+ def extra_repr(self):
874
+ return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
875
+
876
+
877
+ class Qwen3ASRThinkerTextAttention(nn.Module):
878
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
879
+
880
+ def __init__(self, config, layer_idx):
881
+ super().__init__()
882
+ self.config = config
883
+ self.layer_idx = layer_idx
884
+ self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
885
+ self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
886
+ self.scaling = self.head_dim**-0.5
887
+ self.attention_dropout = config.attention_dropout
888
+ self.is_causal = True
889
+
890
+ self.q_proj = nn.Linear(
891
+ config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias
892
+ )
893
+ self.k_proj = nn.Linear(
894
+ config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
895
+ )
896
+ self.v_proj = nn.Linear(
897
+ config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
898
+ )
899
+ self.o_proj = nn.Linear(
900
+ config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
901
+ )
902
+ self.q_norm = Qwen3ASRThinkerTextRMSNorm(
903
+ self.head_dim, eps=config.rms_norm_eps
904
+ ) # unlike olmo, only on the head dim!
905
+ self.k_norm = Qwen3ASRThinkerTextRMSNorm(
906
+ self.head_dim, eps=config.rms_norm_eps
907
+ ) # thus post q_norm does not need reshape
908
+ self.sliding_window = None
909
+
910
+ @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
911
+ def forward(
912
+ self,
913
+ hidden_states: torch.Tensor,
914
+ position_embeddings: tuple[torch.Tensor, torch.Tensor],
915
+ attention_mask: Optional[torch.Tensor],
916
+ past_key_values: Optional[Cache] = None,
917
+ cache_position: Optional[torch.LongTensor] = None,
918
+ **kwargs: Unpack[FlashAttentionKwargs],
919
+ ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
920
+ input_shape = hidden_states.shape[:-1]
921
+ hidden_shape = (*input_shape, -1, self.head_dim)
922
+
923
+ query_states = self.q_norm(self.q_proj(hidden_states).view(hidden_shape)).transpose(1, 2)
924
+ key_states = self.k_norm(self.k_proj(hidden_states).view(hidden_shape)).transpose(1, 2)
925
+ value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
926
+
927
+ cos, sin = position_embeddings
928
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
929
+
930
+ if past_key_values is not None:
931
+ # sin and cos are specific to RoPE models; cache_position needed for the static cache
932
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
933
+ key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)
934
+
935
+ attention_interface: Callable = eager_attention_forward
936
+ if self.config._attn_implementation != "eager":
937
+ attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
938
+
939
+ attn_output, attn_weights = attention_interface(
940
+ self,
941
+ query_states,
942
+ key_states,
943
+ value_states,
944
+ attention_mask,
945
+ dropout=0.0 if not self.training else self.attention_dropout,
946
+ scaling=self.scaling,
947
+ sliding_window=self.sliding_window, # diff with Llama
948
+ **kwargs,
949
+ )
950
+
951
+ attn_output = attn_output.reshape(*input_shape, -1).contiguous()
952
+ attn_output = self.o_proj(attn_output)
953
+ return attn_output, attn_weights
954
+
955
+
956
+ @auto_docstring(
957
+ custom_intro=(
958
+ "Text part of Qwen3ASRThinker, "
959
+ )
960
+ )
961
+ class Qwen3ASRThinkerTextModel(Qwen3ASRPreTrainedModel):
962
+ config: Qwen3ASRConfig
963
+ _no_split_modules = ["Qwen3ASRThinkerTextDecoderLayer"]
964
+ config_class = Qwen3ASRConfig
965
+ _can_record_outputs = {
966
+ "hidden_states": Qwen3ASRThinkerTextDecoderLayer,
967
+ "attentions": Qwen3ASRThinkerTextAttention,
968
+ }
969
+
970
+ def __init__(self, config: Qwen3ASRConfig):
971
+ super().__init__(config)
972
+ self.padding_idx = config.pad_token_id
973
+ self.vocab_size = config.vocab_size
974
+
975
+ self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
976
+ self.layers = nn.ModuleList(
977
+ [Qwen3ASRThinkerTextDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
978
+ )
979
+ self.norm = Qwen3ASRTextRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
980
+ self.rotary_emb = Qwen3ASRThinkerTextRotaryEmbedding(config)
981
+ self.gradient_checkpointing = False
982
+
983
+ # Initialize weights and apply final processing
984
+ self.post_init()
985
+
986
+ @check_model_inputs()
987
+ @auto_docstring
988
+ def forward(
989
+ self,
990
+ input_ids: Optional[torch.LongTensor] = None,
991
+ attention_mask: Optional[torch.Tensor] = None,
992
+ position_ids: Optional[torch.LongTensor] = None,
993
+ past_key_values: Optional[Cache] = None,
994
+ inputs_embeds: Optional[torch.FloatTensor] = None,
995
+ use_cache: Optional[bool] = None,
996
+ cache_position: Optional[torch.LongTensor] = None,
997
+ **kwargs: Unpack[FlashAttentionKwargs],
998
+ ) -> Union[tuple, BaseModelOutputWithPast]:
999
+ if (input_ids is None) ^ (inputs_embeds is not None):
1000
+ raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
1001
+
1002
+ # torch.jit.trace() doesn't support cache objects in the output
1003
+ if use_cache and past_key_values is None and not torch.jit.is_tracing():
1004
+ past_key_values = DynamicCache(config=self.config)
1005
+
1006
+ if inputs_embeds is None:
1007
+ inputs_embeds = self.embed_tokens(input_ids)
1008
+
1009
+ if cache_position is None:
1010
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
1011
+ cache_position = torch.arange(
1012
+ past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
1013
+ )
1014
+
1015
+ # the hard coded `3` is for temporal, height and width.
1016
+ if position_ids is None:
1017
+ position_ids = cache_position.view(1, 1, -1).expand(3, inputs_embeds.shape[0], -1)
1018
+ elif position_ids.ndim == 2:
1019
+ position_ids = position_ids[None, ...].expand(3, position_ids.shape[0], -1)
1020
+
1021
+ if position_ids.ndim == 3 and position_ids.shape[0] == 4:
1022
+ text_position_ids = position_ids[0]
1023
+ position_ids = position_ids[1:]
1024
+ else:
1025
+ text_position_ids = position_ids[0]
1026
+
1027
+ attention_mask = create_causal_mask(
1028
+ config=self.config,
1029
+ input_embeds=inputs_embeds,
1030
+ attention_mask=attention_mask,
1031
+ cache_position=cache_position,
1032
+ past_key_values=past_key_values,
1033
+ position_ids=text_position_ids,
1034
+ )
1035
+
1036
+ hidden_states = inputs_embeds
1037
+
1038
+ # create position embeddings to be shared across the decoder layers
1039
+ position_embeddings = self.rotary_emb(hidden_states, position_ids)
1040
+
1041
+ # decoder layers
1042
+ for layer_idx, decoder_layer in enumerate(self.layers):
1043
+ layer_outputs = decoder_layer(
1044
+ hidden_states,
1045
+ attention_mask=attention_mask,
1046
+ position_ids=text_position_ids,
1047
+ past_key_values=past_key_values,
1048
+ cache_position=cache_position,
1049
+ position_embeddings=position_embeddings,
1050
+ **kwargs,
1051
+ )
1052
+ hidden_states = layer_outputs
1053
+
1054
+ hidden_states = self.norm(hidden_states)
1055
+
1056
+ return BaseModelOutputWithPast(
1057
+ last_hidden_state=hidden_states,
1058
+ past_key_values=past_key_values,
1059
+ )
1060
+
1061
+
1062
+ @auto_docstring(
1063
+ custom_intro="""
1064
+ The Qwen3ASRThinker model, which consists of an audio backbone and a language model.
1065
+ """
1066
+ )
1067
+ class Qwen3ASRThinkerForConditionalGeneration(Qwen3ASRPreTrainedModelForConditionalGeneration, GenerationMixin):
1068
+ config: Qwen3ASRThinkerConfig
1069
+ base_model_prefix = "thinker"
1070
+ _tied_weights_keys = ["model.embed_tokens.weight", "lm_head.weight"]
1071
+ _no_split_modules = [
1072
+ "Qwen3ASRAudioEncoderLayer",
1073
+ "Qwen3ASRThinkerTextDecoderLayer",
1074
+ ]
1075
+ _can_record_outputs = {
1076
+ "hidden_states": Qwen3ASRThinkerTextDecoderLayer,
1077
+ "attentions": Qwen3ASRThinkerTextAttention,
1078
+ }
1079
+
1080
+ def __init__(self, config):
1081
+ super().__init__(config)
1082
+ self.audio_tower = Qwen3ASRAudioEncoder._from_config(config.audio_config)
1083
+ self.vocab_size = config.text_config.vocab_size
1084
+ self.model = Qwen3ASRThinkerTextModel._from_config(config.text_config)
1085
+ if "forced_aligner" in config.model_type:
1086
+ self.lm_head = nn.Linear(config.text_config.hidden_size, config.classify_num, bias=False)
1087
+ else:
1088
+ self.lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False)
1089
+ self.pad_token_id = self.config.pad_token_id if self.config.pad_token_id is not None else -1
1090
+ self.rope_deltas = None
1091
+ self.post_init()
1092
+
1093
+ def get_input_embeddings(self):
1094
+ return self.model.get_input_embeddings()
1095
+
1096
+ def set_input_embeddings(self, value):
1097
+ self.model.set_input_embeddings(value)
1098
+
1099
+ def get_audio_features(
1100
+ self,
1101
+ input_features: torch.FloatTensor,
1102
+ feature_attention_mask: Optional[torch.LongTensor] = None,
1103
+ audio_feature_lengths: Optional[torch.LongTensor] = None,
1104
+ ):
1105
+ """
1106
+ Encodes audios into continuous embeddings that can be forwarded to the language model.
1107
+
1108
+ Args:
1109
+ input_features (`torch.FloatTensor`):
1110
+ The tensors corresponding to the input audios.
1111
+ feature_attention_mask (`torch.LongTensor`, *optional*):
1112
+ Mask to avoid performing attention on padding feature indices. Mask values selected in `[0, 1]`:
1113
+ audio_feature_lengths (`torch.LongTensor` of shape `(num_audios)`, *optional*):
1114
+ The length of feature shape of each audio in LLM.
1115
+ """
1116
+ if feature_attention_mask is not None:
1117
+ audio_feature_lengths = torch.sum(feature_attention_mask, dim=1)
1118
+ input_features = input_features.permute(0, 2, 1)[feature_attention_mask.bool()].permute(1, 0)
1119
+ else:
1120
+ audio_feature_lengths = None
1121
+
1122
+ feature_lens = audio_feature_lengths if audio_feature_lengths is not None else feature_attention_mask.sum(-1)
1123
+ audio_outputs = self.audio_tower(
1124
+ input_features,
1125
+ feature_lens=feature_lens,
1126
+ )
1127
+ audio_features = audio_outputs.last_hidden_state
1128
+
1129
+ return audio_features
1130
+
1131
+ def get_placeholder_mask(
1132
+ self,
1133
+ input_ids: torch.LongTensor,
1134
+ inputs_embeds: torch.FloatTensor,
1135
+ ):
1136
+ """
1137
+ Obtains multimodal placeholder mask from `input_ids` or `inputs_embeds`, and checks that the placeholder token count is
1138
+ equal to the length of multimodal features. If the lengths are different, an error is raised.
1139
+ """
1140
+ if input_ids is None:
1141
+ special_audio_mask = (
1142
+ inputs_embeds
1143
+ == self.get_input_embeddings()(
1144
+ torch.tensor(self.config.audio_token_id, dtype=torch.long, device=inputs_embeds.device)
1145
+ )
1146
+ ).all(-1)
1147
+ else:
1148
+ special_audio_mask = input_ids == self.config.audio_token_id
1149
+
1150
+ special_audio_mask = special_audio_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
1151
+ return special_audio_mask
1152
+
1153
+ @can_return_tuple
1154
+ @auto_docstring
1155
+ def forward(
1156
+ self,
1157
+ input_ids=None,
1158
+ input_features=None,
1159
+ attention_mask=None,
1160
+ feature_attention_mask=None,
1161
+ audio_feature_lengths=None,
1162
+ position_ids=None,
1163
+ past_key_values=None,
1164
+ inputs_embeds=None,
1165
+ rope_deltas=None,
1166
+ labels=None,
1167
+ use_cache=None,
1168
+ cache_position=None,
1169
+ **kwargs,
1170
+ ) -> Union[tuple, Qwen3ASRThinkerCausalLMOutputWithPast]:
1171
+ r"""
1172
+ feature_attention_mask (`torch.Tensor` of shape `(batch_size, feature_sequence_length)`, *optional*):
1173
+ Mask to avoid performing attention on padding feature indices. Mask values selected in `[0, 1]`:
1174
+ - 1 for tokens that are **not masked**,
1175
+ - 0 for tokens that are **masked**.
1176
+ audio_feature_lengths (`torch.LongTensor` of shape `(num_audios)`, *optional*):
1177
+ The length of feature shape of each audio in LLM.
1178
+ rope_deltas (`torch.LongTensor` of shape `(batch_size, )`, *optional*):
1179
+ The rope index difference between sequence length and multimodal rope.
1180
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
1181
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
1182
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
1183
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
1184
+ """
1185
+
1186
+ if inputs_embeds is None:
1187
+ # 1. Extract the input embeddings
1188
+ inputs_embeds = self.get_input_embeddings()(input_ids)
1189
+
1190
+ # 2. Merge text, audios
1191
+ if input_features is not None:
1192
+ audio_features = self.get_audio_features(
1193
+ input_features,
1194
+ feature_attention_mask=feature_attention_mask,
1195
+ audio_feature_lengths=audio_feature_lengths,
1196
+ )
1197
+ audio_features = audio_features.to(inputs_embeds.device, inputs_embeds.dtype)
1198
+ audio_mask = self.get_placeholder_mask(input_ids, inputs_embeds=inputs_embeds)
1199
+ inputs_embeds = inputs_embeds.masked_scatter(audio_mask, audio_features)
1200
+
1201
+ if feature_attention_mask is not None:
1202
+ audio_feature_lengths = torch.sum(feature_attention_mask, dim=1)
1203
+ else:
1204
+ audio_feature_lengths = None
1205
+
1206
+ if attention_mask is not None and position_ids is None:
1207
+ if (
1208
+ cache_position is None
1209
+ or (cache_position is not None and cache_position[0] == 0)
1210
+ or self.rope_deltas is None
1211
+ ):
1212
+ delta0 = (1 - attention_mask).sum(dim=-1).unsqueeze(1)
1213
+ position_ids, rope_deltas = self.get_rope_index(
1214
+ attention_mask,
1215
+ )
1216
+ rope_deltas = rope_deltas - delta0
1217
+ self.rope_deltas = rope_deltas
1218
+ else:
1219
+ batch_size, seq_length = input_ids.shape
1220
+ delta = cache_position[0] + self.rope_deltas if cache_position is not None else 0
1221
+ position_ids = torch.arange(seq_length, device=input_ids.device)
1222
+ position_ids = position_ids.view(1, -1).expand(batch_size, -1)
1223
+ position_ids = position_ids.add(delta)
1224
+ position_ids = position_ids.unsqueeze(0).expand(3, -1, -1)
1225
+
1226
+ outputs = self.model(
1227
+ attention_mask=attention_mask,
1228
+ position_ids=position_ids,
1229
+ past_key_values=past_key_values,
1230
+ inputs_embeds=inputs_embeds,
1231
+ use_cache=use_cache,
1232
+ cache_position=cache_position,
1233
+ **kwargs,
1234
+ )
1235
+
1236
+ hidden_states = outputs[0]
1237
+ logits = self.lm_head(hidden_states)
1238
+
1239
+ loss = None
1240
+ if labels is not None:
1241
+ loss = self.loss_function(
1242
+ logits=logits, labels=labels, vocab_size=self.config.get_text_config().vocab_size
1243
+ )
1244
+
1245
+ return Qwen3ASRThinkerCausalLMOutputWithPast(
1246
+ loss=loss,
1247
+ logits=logits,
1248
+ hidden_states=outputs.hidden_states,
1249
+ attentions=outputs.attentions,
1250
+ past_key_values=outputs.past_key_values,
1251
+ rope_deltas=self.rope_deltas,
1252
+ )
1253
+
1254
+ def prepare_inputs_for_generation(
1255
+ self,
1256
+ input_ids,
1257
+ past_key_values=None,
1258
+ attention_mask=None,
1259
+ inputs_embeds=None,
1260
+ cache_position=None,
1261
+ position_ids=None,
1262
+ use_cache=True,
1263
+ input_features=None,
1264
+ feature_attention_mask=None,
1265
+ **kwargs,
1266
+ ):
1267
+ model_inputs = super().prepare_inputs_for_generation(
1268
+ input_ids,
1269
+ past_key_values=past_key_values,
1270
+ attention_mask=attention_mask,
1271
+ inputs_embeds=inputs_embeds,
1272
+ cache_position=cache_position,
1273
+ position_ids=position_ids,
1274
+ use_cache=use_cache,
1275
+ input_features=input_features,
1276
+ feature_attention_mask=feature_attention_mask,
1277
+ **kwargs,
1278
+ )
1279
+
1280
+ model_inputs["position_ids"] = None
1281
+
1282
+ if cache_position[0] != 0:
1283
+ model_inputs["input_features"] = None
1284
+
1285
+ return model_inputs
1286
+
1287
+
1288
+ @auto_docstring
1289
+ class Qwen3ASRThinkerTextPreTrainedModel(PreTrainedModel):
1290
+ config = Qwen3ASRConfig
1291
+ base_model_prefix = "model"
1292
+ supports_gradient_checkpointing = True
1293
+ _no_split_modules = ["Qwen3ASRThinkerTextDecoderLayer"]
1294
+ _skip_keys_device_placement = ["past_key_values"]
1295
+ _supports_flash_attn = True
1296
+ _supports_sdpa = True
1297
+ _supports_flex_attn = True
1298
+ _can_compile_fullgraph = False # MoE models don't work with torch.compile (`torch.where(condition)` not supported)
1299
+ _supports_attention_backend = True
1300
+ _can_record_outputs = {
1301
+ "hidden_states": Qwen3ASRThinkerTextDecoderLayer,
1302
+ "attentions": Qwen3ASRThinkerTextAttention,
1303
+ }
1304
+ config_class = Qwen3ASRConfig
1305
+
1306
+
1307
+ class Qwen3ASRForConditionalGeneration(Qwen3ASRPreTrainedModel, GenerationMixin):
1308
+ config_class = Qwen3ASRConfig
1309
+
1310
+ def __init__(self, config: Qwen3ASRConfig):
1311
+ super().__init__(config)
1312
+ self.config = config
1313
+
1314
+ self.thinker = Qwen3ASRThinkerForConditionalGeneration._from_config(config.thinker_config)
1315
+ self.post_init()
1316
+
1317
+ def get_support_languages(self):
1318
+ return self.config.support_languages
1319
+
1320
+ @torch.no_grad()
1321
+ def generate(
1322
+ self,
1323
+ input_ids: Optional[torch.Tensor] = None,
1324
+ max_new_tokens: int = 8192,
1325
+ eos_token_id: int | list[int] = [151645, 151643],
1326
+ **kwargs,
1327
+ ):
1328
+ shared_kwargs = {}
1329
+ thinker_kwargs = {
1330
+ "max_new_tokens": max_new_tokens,
1331
+ "eos_token_id": eos_token_id,
1332
+ }
1333
+
1334
+ for key, value in kwargs.items():
1335
+ # Process special input values
1336
+ if key == "feature_attention_mask":
1337
+ thinker_kwargs[key] = value
1338
+ elif key in ("input_features", "attention_mask"):
1339
+ thinker_kwargs[key] = value
1340
+ # Put other key to shared kwargs
1341
+ else:
1342
+ shared_kwargs[key] = value
1343
+
1344
+ # Merge kwargs
1345
+ for key, value in shared_kwargs.items():
1346
+ if key not in thinker_kwargs:
1347
+ thinker_kwargs[key] = value
1348
+
1349
+ thinker_result = self.thinker.generate(input_ids=input_ids, return_dict_in_generate=True, **thinker_kwargs)
1350
+
1351
+ return thinker_result
1352
+
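+ # Minimal usage sketch (illustrative only; the checkpoint id, prompt construction and variable
+ # names below are assumptions, not part of this file):
+ #     processor = Qwen3ASRProcessor.from_pretrained("Qwen/Qwen3-ASR-1.7B")
+ #     model = Qwen3ASRForConditionalGeneration.from_pretrained("Qwen/Qwen3-ASR-1.7B")
+ #     inputs = processor(text=prompt_with_audio_token, audio=waveform, return_tensors="pt")
+ #     out = model.generate(**inputs, max_new_tokens=256)
+ #     text = processor.tokenizer.batch_decode(out.sequences, skip_special_tokens=True)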
1353
+
1354
+ __all__ = [
1355
+ "Qwen3ASRForConditionalGeneration",
1356
+ "Qwen3ASRThinkerTextModel",
1357
+ "Qwen3ASRThinkerForConditionalGeneration",
1358
+ "Qwen3ASRPreTrainedModel",
1359
+ "Qwen3ASRPreTrainedModelForConditionalGeneration",
1360
+ "Qwen3ASRThinkerTextPreTrainedModel",
1361
+ ]
qwen_asr/core/transformers_backend/processing_qwen3_asr.py ADDED
@@ -0,0 +1,209 @@
1
+ # coding=utf-8
2
+ # Copyright 2026 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ import re
16
+
17
+ import numpy as np
18
+
19
+ from transformers.audio_utils import AudioInput
20
+ from transformers.feature_extraction_utils import BatchFeature
21
+ from transformers.processing_utils import ProcessingKwargs, ProcessorMixin
22
+ from transformers.tokenization_utils_base import TextInput
23
+
24
+
25
+ class Qwen3ASRProcessorKwargs(ProcessingKwargs, total=False):
26
+ _defaults = {
27
+ "text_kwargs": {
28
+ "padding": False,
29
+ "padding_side": "left",
30
+ },
31
+ "audio_kwargs": {
32
+ "sampling_rate": 16000,
33
+ "padding": True,
34
+ "return_attention_mask": True,
35
+ },
36
+ }
37
+
38
+
39
+ def _get_feat_extract_output_lengths(input_lengths):
40
+ """
41
+ Computes the output length of the convolutional layers and the output length of the audio encoder
42
+ """
43
+
44
+ input_lengths_leave = input_lengths % 100
45
+ feat_lengths = (input_lengths_leave - 1) // 2 + 1
46
+ output_lengths = ((feat_lengths - 1) // 2 + 1 - 1) // 2 + 1 + (input_lengths // 100) * 13
47
+ return output_lengths
48
+
49
+
50
+ class Qwen3ASRProcessor(ProcessorMixin):
51
+ r"""
52
+ Constructs a Qwen3ASR processor.
53
+ [`Qwen3ASRProcessor`] offers all the functionalities of [`WhisperFeatureExtractor`] and [`Qwen2TokenizerFast`]. See the
54
+ [`~Qwen3ASRProcessor.__call__`] and [`~Qwen3ASRProcessor.decode`] for more information.
55
+
56
+ Args:
57
+ feature_extractor ([`WhisperFeatureExtractor`], *optional*):
58
+ The audio feature extractor.
59
+ tokenizer ([`Qwen2TokenizerFast`], *optional*):
60
+ The text tokenizer.
61
+ chat_template (`Optional[str]`, *optional*):
62
+ The Jinja template to use for formatting the conversation. If not provided, the default chat template is used.
63
+ """
64
+
65
+ attributes = ["feature_extractor", "tokenizer"]
66
+ feature_extractor_class = "WhisperFeatureExtractor"
67
+ tokenizer_class = ("Qwen2Tokenizer", "Qwen2TokenizerFast")
68
+
69
+ def __init__(
70
+ self, feature_extractor=None, tokenizer=None, chat_template=None
71
+ ):
72
+ super().__init__(feature_extractor, tokenizer, chat_template=chat_template)
73
+ self.audio_token = self.tokenizer.audio_token
74
+ self.audio_bos_token = self.tokenizer.audio_bos_token
75
+ self.audio_eos_token = self.tokenizer.audio_eos_token
76
+
77
+ def __call__(
78
+ self,
79
+ text: TextInput = None,
80
+ audio: AudioInput = None,
81
+ **kwargs,
82
+ ) -> BatchFeature:
83
+ """
84
+ Main method to prepare one or several sequence(s) and audio(s) for the model. This method forwards the `text`
+ and `kwargs` arguments to Qwen2TokenizerFast's [`~Qwen2TokenizerFast.__call__`] if `text` is not `None` to encode
+ the text. To prepare the audio(s), this method forwards the `audio` and `kwargs` arguments to
+ WhisperFeatureExtractor's [`~WhisperFeatureExtractor.__call__`] if `audio` is not `None`. Please refer to the docstring
88
+ of the above two methods for more information.
89
+
90
+ Args:
91
+ text (`str`, `List[str]`, `List[List[str]]`):
92
+ The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
93
+ (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
94
+ `is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
95
+ audio (`np.ndarray`, `List[np.ndarray]`):
96
+ The audio or batch of audio to be prepared. Each audio can be a NumPy array.
97
+ """
98
+
99
+ if text is None:
100
+ raise ValueError("You need to specify either a `text` input to process.")
101
+
102
+ output_kwargs = self._merge_kwargs(
103
+ Qwen3ASRProcessorKwargs,
104
+ tokenizer_init_kwargs=self.tokenizer.init_kwargs,
105
+ **kwargs,
106
+ )
107
+
108
+ if audio is not None:
109
+ output_kwargs["audio_kwargs"]["padding"] = True
110
+ output_kwargs["audio_kwargs"]["truncation"] = False
111
+ audio_inputs = self.feature_extractor(audio, **output_kwargs["audio_kwargs"])
112
+ audio_inputs["feature_attention_mask"] = audio_inputs.pop(
113
+ "attention_mask"
114
+ ) # rename feature_attention_mask to prevent conflicts later on
115
+ audio_inputs["input_features"] = audio_inputs.pop(
116
+ "input_features"
117
+ ) # rename input_features to prevent conflicts later on
118
+ audio_lengths = iter(_get_feat_extract_output_lengths(audio_inputs["feature_attention_mask"].sum(-1)))
119
+ else:
120
+ audio_inputs = {}
121
+ audio_lengths = iter([])
122
+
123
+ if not isinstance(text, list):
124
+ text = [text]
125
+
126
+ text = self.replace_multimodal_special_tokens(
127
+ text,
128
+ audio_lengths,
129
+ )
130
+
131
+ texts_inputs = self.tokenizer(text, **output_kwargs["text_kwargs"])
132
+
133
+ return BatchFeature(
134
+ data={**texts_inputs, **audio_inputs},
135
+ tensor_type=kwargs.get("return_tensors"),
136
+ )
137
+
138
+ def replace_multimodal_special_tokens(
139
+ self,
140
+ text,
141
+ audio_lengths,
142
+ ):
143
+
144
+ processed_text = []
145
+ for sample in text:
146
+ # Locate every multimodal special token in the sample, ordered by position.
+ pattern = "|".join(re.escape(tok) for tok in [self.audio_token])
+ positions = sorted((match.start(), match.group()) for match in re.finditer(pattern, sample))
151
+
152
+ for _, special_token in positions:
153
+ if special_token == self.audio_token:
154
+ sample = sample.replace(self.audio_token, "<|audio_placeholder|>" * next(audio_lengths), 1)
155
+
156
+ sample = sample.replace("<|audio_placeholder|>", self.audio_token)
157
+ processed_text.append(sample)
158
+ return processed_text
159
+
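+ # Illustrative example (the literal token strings are assumptions; the real values come from
+ # tokenizer.audio_token / audio_bos_token / audio_eos_token): if the audio token were "<|AUDIO|>"
+ # and its audio produced 21 encoder frames, "<|audio_bos|><|AUDIO|><|audio_eos|>" would be
+ # expanded into 21 consecutive audio tokens between the bos/eos markers, so the number of
+ # placeholders matches the number of audio embeddings scattered in by the model.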
160
+ def get_chunked_index(self, token_indices: np.ndarray, tokens_per_chunk: int) -> list[tuple[int, int]]:
161
+ """
162
+ Splits token index list into chunks based on token value ranges.
163
+
164
+ Given a list of token indices, returns a list of (start, end) index tuples representing
165
+ slices of the list where the token values fall within successive ranges of `tokens_per_chunk`.
+
+ For example, if `tokens_per_chunk` is 1000, the function will create chunks such that:
+ - the first chunk contains token values < 1000,
+ - the second chunk contains values >= 1000 and < 2000, and so on.
+
+ Parameters:
+ token_indices (`np.ndarray`): A monotonically increasing list of token index values.
+ tokens_per_chunk (`int`): Number of tokens per chunk (used as the chunk size threshold).
174
+
175
+ Returns:
176
+ `list[tuple[int, int]]`: A list of tuples, each representing the start (inclusive)
177
+ and end (exclusive) indices of a chunk in `token_indices`.
178
+ """
179
+
180
+ def _iter():
181
+ i, start_idx = 0, 0 # skip bos token
182
+ current_chunk = 1
183
+ while i < len(token_indices): # skip eos token
184
+ if token_indices[i] >= current_chunk * tokens_per_chunk:
185
+ yield (start_idx, i)
186
+ start_idx = i
187
+ current_chunk += 1
188
+ i += 1
189
+ yield (start_idx, len(token_indices))
190
+
191
+ return list(_iter())
192
+
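+ # Worked example (illustrative only): with tokens_per_chunk == 1000,
+ #     self.get_chunked_index(np.array([0, 5, 999, 1000, 1500, 2100]), 1000)
+ # returns [(0, 3), (3, 5), (5, 6)], i.e. indices grouped by which multiple-of-1000 range
+ # their token values fall into.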
193
+ def apply_chat_template(self, conversations, chat_template=None, **kwargs):
194
+ return super().apply_chat_template(conversations, chat_template, **kwargs)
195
+
196
+ @property
197
+ def model_input_names(self):
198
+ tokenizer_input_names = self.tokenizer.model_input_names
199
+ feature_extractor_input_names = self.feature_extractor.model_input_names
200
+ return list(
201
+ dict.fromkeys(
202
+ tokenizer_input_names
203
+ + feature_extractor_input_names
204
+ + ["feature_attention_mask"]
205
+ )
206
+ )
207
+
208
+
209
+ __all__ = ["Qwen3ASRProcessor"]
qwen_asr/core/vllm_backend/__init__.py ADDED
@@ -0,0 +1,16 @@
1
+ # coding=utf-8
2
+ # Copyright 2026 The Alibaba Qwen team.
3
+ # SPDX-License-Identifier: Apache-2.0
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ from .qwen3_asr import Qwen3ASRForConditionalGeneration
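
The class exported here is an out-of-tree vLLM model: it only becomes usable once it is registered with vLLM's `ModelRegistry`, which `qwen_asr/inference/qwen3_asr.py` below does at import time. A minimal sketch of that registration (assuming vLLM is installed):

```python
from vllm import ModelRegistry
from qwen_asr.core.vllm_backend import Qwen3ASRForConditionalGeneration

# Make the "Qwen3ASRForConditionalGeneration" architecture name in config.json
# resolvable by vLLM before constructing an LLM instance.
ModelRegistry.register_model(
    "Qwen3ASRForConditionalGeneration", Qwen3ASRForConditionalGeneration
)
```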
qwen_asr/core/vllm_backend/qwen3_asr.py ADDED
@@ -0,0 +1,997 @@
1
+ # SPDX-License-Identifier: Apache-2.0
2
+ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
+ # Copyright 2026 The Qwen team.
4
+ # Copyright 2023 The vLLM team.
5
+ # Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
6
+ #
7
+ # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
8
+ # and OPT implementations in this library. It has been modified from its
9
+ # original forms to accommodate minor architectural differences compared
10
+ # to GPT-NeoX and OPT used by the Meta AI team that trained the model.
11
+ #
12
+ # Licensed under the Apache License, Version 2.0 (the "License");
13
+ # you may not use this file except in compliance with the License.
14
+ # You may obtain a copy of the License at
15
+ #
16
+ # http://www.apache.org/licenses/LICENSE-2.0
17
+ #
18
+ # Unless required by applicable law or agreed to in writing, software
19
+ # distributed under the License is distributed on an "AS IS" BASIS,
20
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
21
+ # See the License for the specific language governing permissions and
22
+ # limitations under the License.
23
+ """Inference-only Qwen3-ASR model."""
24
+
25
+ from collections.abc import Iterable, Mapping, Sequence
26
+ from typing import Any, Literal, cast
27
+
28
+ import numpy as np
29
+ import torch
30
+ import torch.nn as nn
31
+ import torch.nn.functional as F
32
+ from transformers.feature_extraction_utils import BatchFeature
33
+ from transformers.models.whisper import WhisperFeatureExtractor
34
+
35
+ from vllm.config import MultiModalConfig, ModelConfig, SpeechToTextConfig, VllmConfig
36
+ from vllm.config.multimodal import BaseDummyOptions
37
+ from vllm.distributed import get_tensor_model_parallel_world_size
38
+ from vllm.inputs.data import PromptType
39
+ from vllm.logger import init_logger
40
+ from vllm.model_executor.layers.activation import _ACTIVATION_REGISTRY
41
+ from vllm.model_executor.layers.attention.mm_encoder_attention import (
42
+ MMEncoderAttention,
43
+ )
44
+ from vllm.model_executor.layers.linear import (
45
+ ColumnParallelLinear,
46
+ QKVParallelLinear,
47
+ RowParallelLinear,
48
+ )
49
+ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
50
+ from vllm.model_executor.models.interfaces import (
51
+ MultiModalEmbeddings,
52
+ SupportsMRoPE,
53
+ SupportsMultiModal,
54
+ SupportsPP,
55
+ SupportsTranscription,
56
+ )
57
+ from vllm.model_executor.models.module_mapping import MultiModelKeys
58
+ from vllm.model_executor.models.qwen3 import Qwen3ForCausalLM
59
+ from vllm.model_executor.models.qwen3_omni_moe_thinker import (
60
+ Qwen2_5OmniAudioFeatureInputs,
61
+ Qwen3OmniMoeThinkerMultiModalProcessor,
62
+ )
63
+ from vllm.model_executor.models.utils import (
64
+ AutoWeightsLoader,
65
+ WeightsMapper,
66
+ _merge_multimodal_embeddings,
67
+ maybe_prefix,
68
+ )
69
+ from vllm.model_executor.models.whisper import ISO639_1_SUPPORTED_LANGS
70
+ from vllm.multimodal import MULTIMODAL_REGISTRY
71
+ from vllm.multimodal.inputs import (
72
+ AudioItem,
73
+ ModalityData,
74
+ MultiModalDataDict,
75
+ MultiModalFeatureSpec,
76
+ MultiModalFieldConfig,
77
+ MultiModalKwargsItems,
78
+ )
79
+ from vllm.multimodal.parse import (
80
+ AudioProcessorItems,
81
+ DictEmbeddingItems,
82
+ ModalityDataItems,
83
+ MultiModalDataItems,
84
+ MultiModalDataParser,
85
+ )
86
+ from vllm.multimodal.processing import (
87
+ BaseProcessingInfo,
88
+ PromptReplacement,
89
+ PromptUpdate,
90
+ )
91
+ from vllm.sequence import IntermediateTensors
92
+ from vllm.v1.attention.backends.registry import AttentionBackendEnum
93
+ from vllm.tokenizers import cached_tokenizer_from_config
94
+ from vllm.transformers_utils.processor import cached_processor_from_config
95
+ from vllm.model_executor.models.vision import (
96
+ get_vit_attn_backend,
97
+ )
98
+ from ..transformers_backend.configuration_qwen3_asr import (
99
+ Qwen3ASRConfig,
100
+ Qwen3ASRThinkerConfig,
101
+ Qwen3ASRAudioEncoderConfig
102
+ )
103
+ from ..transformers_backend.processing_qwen3_asr import (
104
+ Qwen3ASRProcessor,
105
+ )
106
+
107
+ try:
108
+ from vllm.multimodal.profiling import BaseDummyInputsBuilder
109
+ except:
110
+ from vllm.multimodal.processing import BaseDummyInputsBuilder
111
+
112
+ logger = init_logger(__name__)
113
+
114
+
115
+ def _get_feat_extract_output_lengths(input_lengths: torch.Tensor):
116
+ input_lengths_leave = input_lengths % 100
117
+ feat_lengths = (input_lengths_leave - 1) // 2 + 1
118
+ output_lengths = (
119
+ ((feat_lengths - 1) // 2 + 1 - 1) // 2 + 1 + (input_lengths // 100) * 13
120
+ )
121
+ return output_lengths
122
+
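
`_get_feat_extract_output_lengths` maps mel-frame counts (the Whisper extractor produces 100 frames per second at a 10 ms hop) to audio-token counts: 13 tokens per full second of input plus a downsampled remainder. A quick standalone check of that arithmetic (illustrative input lengths only):

```python
import torch

def get_feat_extract_output_lengths(input_lengths: torch.Tensor) -> torch.Tensor:
    # Same arithmetic as _get_feat_extract_output_lengths above.
    input_lengths_leave = input_lengths % 100
    feat_lengths = (input_lengths_leave - 1) // 2 + 1
    return ((feat_lengths - 1) // 2 + 1 - 1) // 2 + 1 + (input_lengths // 100) * 13

# 3000 mel frames (30 s) and 1245 mel frames (12.45 s):
print(get_feat_extract_output_lengths(torch.tensor([3000, 1245])))
# -> tensor([390, 162])
```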
123
+
124
+ # ============= Audio Encoder Components =============
125
+
126
+
127
+ class SinusoidsPositionEmbedding(nn.Module):
128
+ """Sinusoidal position embedding for audio encoder."""
129
+
130
+ def __init__(self, length: int, channels: int, max_timescale: int = 10000):
131
+ super().__init__()
132
+ self.length = length
133
+ self.channels = channels
134
+ self.max_timescale = max_timescale
135
+
136
+ if channels % 2 != 0:
137
+ raise ValueError("SinusoidsPositionEmbedding needs even channels input")
138
+
139
+ log_timescale_increment = np.log(max_timescale) / (channels // 2 - 1)
140
+ inv_timescales = torch.exp(
141
+ -log_timescale_increment * torch.arange(channels // 2).float()
142
+ )
143
+ scaled_time = (
144
+ torch.arange(length)[:, np.newaxis] * inv_timescales[np.newaxis, :]
145
+ )
146
+ positional_embedding = torch.cat(
147
+ [torch.sin(scaled_time), torch.cos(scaled_time)], dim=1
148
+ )
149
+ self.register_buffer(
150
+ "positional_embedding", positional_embedding, persistent=False
151
+ )
152
+
153
+ def forward(self, seqlen: int) -> torch.Tensor:
154
+ return self.positional_embedding[:seqlen, :]
155
+
156
+
157
+ class Qwen3ASRAudioAttention(nn.Module):
158
+ """Multi-headed attention for Qwen3-Omni Audio Encoder using MMEncoderAttention."""
159
+
160
+ def __init__(
161
+ self,
162
+ config: Qwen3ASRAudioEncoderConfig,
163
+ multimodal_config: MultiModalConfig | None = None,
164
+ prefix: str = "",
165
+ ):
166
+ super().__init__()
167
+ self.embed_dim = config.d_model
168
+ self.num_heads = config.encoder_attention_heads
169
+ self.head_dim = self.embed_dim // self.num_heads
170
+ tp_size = get_tensor_model_parallel_world_size()
171
+ self.num_local_heads = self.num_heads // tp_size
172
+
173
+ if (self.head_dim * self.num_heads) != self.embed_dim:
174
+ raise ValueError(
175
+ f"embed_dim must be divisible by num_heads (got `embed_dim`: "
176
+ f"{self.embed_dim} and `num_heads`: {self.num_heads})."
177
+ )
178
+
179
+ self.scaling = self.head_dim**-0.5
180
+
181
+ self.qkv = QKVParallelLinear(
182
+ hidden_size=self.embed_dim,
183
+ head_size=self.head_dim,
184
+ total_num_heads=self.num_heads,
185
+ total_num_kv_heads=self.num_heads,
186
+ bias=True,
187
+ prefix=f"{prefix}.qkv",
188
+ )
189
+
190
+ self.out_proj = RowParallelLinear(
191
+ input_size=self.embed_dim,
192
+ output_size=self.embed_dim,
193
+ bias=True,
194
+ prefix=f"{prefix}.out_proj",
195
+ )
196
+
197
+ self.attn = MMEncoderAttention(
198
+ num_heads=self.num_local_heads,
199
+ head_size=self.head_dim,
200
+ scale=self.scaling,
201
+ multimodal_config=multimodal_config,
202
+ )
203
+
204
+ def forward(
205
+ self,
206
+ hidden_states: torch.Tensor,
207
+ cu_seqlens: torch.Tensor,
208
+ max_seqlen: torch.Tensor | None,
209
+ ) -> torch.Tensor:
210
+ seq_length, _ = hidden_states.size()
211
+ qkv, _ = self.qkv(hidden_states)
212
+ q, k, v = qkv.chunk(3, dim=-1)
213
+ q = q.view(1, seq_length, -1, self.head_dim)
214
+ k = k.view(1, seq_length, -1, self.head_dim)
215
+ v = v.view(1, seq_length, -1, self.head_dim)
216
+
217
+ attn_output = self.attn(
218
+ query=q,
219
+ key=k,
220
+ value=v,
221
+ cu_seqlens=cu_seqlens,
222
+ max_seqlen=max_seqlen,
223
+ )
224
+
225
+ attn_output = attn_output.view(seq_length, -1)
226
+ output, _ = self.out_proj(attn_output)
227
+ return output
228
+
229
+
230
+ class Qwen3ASRAudioEncoderLayer(nn.Module):
231
+ """Transformer encoder layer for Qwen3-Omni Audio Encoder."""
232
+
233
+ def __init__(
234
+ self,
235
+ config: Qwen3ASRAudioEncoderConfig,
236
+ multimodal_config: MultiModalConfig | None = None,
237
+ prefix: str = "",
238
+ ):
239
+ super().__init__()
240
+ self.embed_dim = config.d_model
241
+ self.self_attn = Qwen3ASRAudioAttention(
242
+ config, multimodal_config=multimodal_config, prefix=f"{prefix}.self_attn"
243
+ )
244
+ self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
245
+ self.activation_fn = _ACTIVATION_REGISTRY[config.activation_function]
246
+ self.fc1 = ColumnParallelLinear(
247
+ self.embed_dim,
248
+ config.encoder_ffn_dim,
249
+ bias=True,
250
+ prefix=f"{prefix}.fc1",
251
+ )
252
+ self.fc2 = RowParallelLinear(
253
+ config.encoder_ffn_dim,
254
+ self.embed_dim,
255
+ bias=True,
256
+ prefix=f"{prefix}.fc2",
257
+ )
258
+ self.final_layer_norm = nn.LayerNorm(self.embed_dim)
259
+
260
+ def forward(
261
+ self,
262
+ hidden_states: torch.Tensor,
263
+ cu_seqlens: torch.Tensor,
264
+ max_seqlen: torch.Tensor | None,
265
+ ) -> torch.Tensor:
266
+ """
267
+ Args:
268
+ hidden_states: Input tensor of shape (seq_len, hidden_size)
269
+ cu_seqlens: Cumulative sequence lengths
270
+ max_seqlen: Maximum sequence length in the batch
271
+ """
272
+ residual = hidden_states
273
+ hidden_states = self.self_attn_layer_norm(hidden_states)
274
+ hidden_states = self.self_attn(
275
+ hidden_states=hidden_states,
276
+ cu_seqlens=cu_seqlens,
277
+ max_seqlen=max_seqlen,
278
+ )
279
+ hidden_states = residual + hidden_states
280
+
281
+ residual = hidden_states
282
+ hidden_states = self.final_layer_norm(hidden_states)
283
+ hidden_states, _ = self.fc1(hidden_states)
284
+ hidden_states = self.activation_fn(hidden_states)
285
+ hidden_states, _ = self.fc2(hidden_states)
286
+ hidden_states = residual + hidden_states
287
+
288
+ # Clamp for numerical stability with fp16
289
+ if hidden_states.dtype == torch.float16:
290
+ clamp_value = torch.finfo(hidden_states.dtype).max - 1000
291
+ hidden_states = torch.clamp(
292
+ hidden_states, min=-clamp_value, max=clamp_value
293
+ )
294
+
295
+ return hidden_states
296
+
297
+
298
+ class Qwen3ASRAudioEncoder(nn.Module):
299
+ """vLLM-native Qwen3-ASR Audio Encoder."""
300
+
301
+ def __init__(
302
+ self,
303
+ config: Qwen3ASRAudioEncoderConfig,
304
+ multimodal_config: MultiModalConfig | None = None,
305
+ prefix: str = "",
306
+ ):
307
+ super().__init__()
308
+
309
+ embed_dim = config.d_model
310
+ self.num_mel_bins = config.num_mel_bins
311
+ self.max_source_positions = config.max_source_positions
312
+ self.n_window = config.n_window
313
+ self.n_window_infer = config.n_window_infer
314
+ self.conv_chunksize = config.conv_chunksize
315
+
316
+ # Position embedding
317
+ self.positional_embedding = SinusoidsPositionEmbedding(
318
+ self.max_source_positions, embed_dim
319
+ )
320
+
321
+ # Convolutional layers for mel-spectrogram processing
322
+ self.conv2d1 = nn.Conv2d(1, config.downsample_hidden_size, 3, 2, padding=1)
323
+ self.conv2d2 = nn.Conv2d(
324
+ config.downsample_hidden_size,
325
+ config.downsample_hidden_size,
326
+ 3,
327
+ 2,
328
+ padding=1,
329
+ )
330
+ self.conv2d3 = nn.Conv2d(
331
+ config.downsample_hidden_size,
332
+ config.downsample_hidden_size,
333
+ 3,
334
+ 2,
335
+ padding=1,
336
+ )
337
+
338
+ conv_out_dim = config.downsample_hidden_size * (
339
+ (((config.num_mel_bins + 1) // 2 + 1) // 2 + 1) // 2
340
+ )
341
+ self.conv_out = nn.Linear(conv_out_dim, config.d_model, bias=False)
342
+
343
+ # Transformer encoder layers
344
+ self.layers = nn.ModuleList(
345
+ [
346
+ Qwen3ASRAudioEncoderLayer(
347
+ config,
348
+ multimodal_config=multimodal_config,
349
+ prefix=f"{prefix}.layers.{i}",
350
+ )
351
+ for i in range(config.encoder_layers)
352
+ ]
353
+ )
354
+
355
+ # Output layers
356
+ self.ln_post = nn.LayerNorm(config.d_model)
357
+ self.proj1 = nn.Linear(config.d_model, config.d_model)
358
+ self.act = _ACTIVATION_REGISTRY[config.activation_function]
359
+ self.proj2 = nn.Linear(config.d_model, config.output_dim)
360
+
361
+ # Get attention backend
362
+ attn_backend_override = (
363
+ multimodal_config.mm_encoder_attn_backend
364
+ if multimodal_config is not None
365
+ else None
366
+ )
367
+ self.attn_backend = get_vit_attn_backend(
368
+ head_size=config.d_model // config.encoder_attention_heads,
369
+ dtype=torch.get_default_dtype(),
370
+ attn_backend_override=attn_backend_override,
371
+ )
372
+
373
+ def compute_attn_mask_seqlen(self, cu_seqlens: torch.Tensor) -> torch.Tensor | None:
374
+ """Compute max_seqlen only for flash attention backends."""
375
+ max_seqlen = None
376
+ if self.attn_backend in {
377
+ AttentionBackendEnum.FLASH_ATTN,
378
+ AttentionBackendEnum.ROCM_AITER_FA,
379
+ }:
380
+ max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max()
381
+ return max_seqlen
382
+
383
+ @property
384
+ def dtype(self) -> torch.dtype:
385
+ return self.conv2d1.weight.dtype
386
+
387
+ @property
388
+ def device(self) -> torch.device:
389
+ return self.conv2d1.weight.device
390
+
391
+ def forward(
392
+ self,
393
+ input_features: torch.Tensor,
394
+ feature_lens: torch.Tensor,
395
+ aftercnn_lens: torch.Tensor,
396
+ ):
397
+ # Compute chunk information
398
+ chunk_num = torch.ceil(feature_lens / (self.n_window * 2)).long()
399
+
400
+ chunk_lengths = torch.tensor(
401
+ [self.n_window * 2] * chunk_num.sum(),
402
+ dtype=torch.long,
403
+ device=feature_lens.device,
404
+ )
405
+ tail_chunk_index = F.pad(chunk_num, (1, 0), value=-1).cumsum(0)[1:]
406
+ chunk_lengths[tail_chunk_index] = feature_lens % (self.n_window * 2)
407
+ chunk_lengths[chunk_lengths == 0] = self.n_window * 2
408
+
409
+ # Split input features into chunks and pad
410
+ chunk_list = input_features.T.split(chunk_lengths.tolist(), dim=0)
411
+ padded_feature = nn.utils.rnn.pad_sequence(
412
+ chunk_list, batch_first=True
413
+ ).transpose(1, 2)
414
+
415
+ # Compute feature lengths after CNN
416
+ feature_lens_after_cnn = self._get_cnn_output_lengths(chunk_lengths)
417
+ # Vectorized mask creation: avoid creating many small tensors
418
+ max_len_after_cnn = feature_lens_after_cnn.max().item()
419
+ indices = torch.arange(max_len_after_cnn, device=padded_feature.device)
420
+ padded_mask_after_cnn = indices.unsqueeze(0) < feature_lens_after_cnn.unsqueeze(
421
+ 1
422
+ )
423
+
424
+ # Add channel dimension for conv2d
425
+ padded_feature = padded_feature.unsqueeze(1)
426
+
427
+ # Apply convolutional layers (chunk if needed to avoid OOM)
428
+ if padded_feature.size(0) <= self.conv_chunksize:
429
+ # Fast path: no chunking needed
430
+ padded_embed = F.gelu(self.conv2d1(padded_feature))
431
+ padded_embed = F.gelu(self.conv2d2(padded_embed))
432
+ padded_embed = F.gelu(self.conv2d3(padded_embed))
433
+ else:
434
+ # Chunked processing to avoid OOM
435
+ padded_embeds = []
436
+ for chunk in padded_feature.split(self.conv_chunksize, dim=0):
437
+ padded_embed = F.gelu(self.conv2d1(chunk))
438
+ padded_embed = F.gelu(self.conv2d2(padded_embed))
439
+ padded_embed = F.gelu(self.conv2d3(padded_embed))
440
+ padded_embeds.append(padded_embed)
441
+ padded_embed = torch.cat(padded_embeds, dim=0)
442
+
443
+ # (batch, channels, freq, time) -> (batch, time, channels*freq)
444
+ b, c, f, t = padded_embed.size()
445
+ padded_embed = self.conv_out(
446
+ padded_embed.permute(0, 3, 1, 2).contiguous().view(b, t, c * f)
447
+ )
448
+
449
+ # Add positional embedding
450
+ positional_embedding = (
451
+ self.positional_embedding.positional_embedding[: padded_embed.shape[1], :]
452
+ .unsqueeze(0)
453
+ .to(padded_embed.dtype)
454
+ )
455
+ padded_embed = padded_embed + positional_embedding
456
+
457
+ # Extract valid hidden states and compute cu_seqlens
458
+ hidden_states = padded_embed[padded_mask_after_cnn]
459
+
460
+ # Compute cumulative sequence lengths for chunked attention
461
+ cu_chunk_lens = [0]
462
+ window_aftercnn = padded_mask_after_cnn.shape[-1] * (
463
+ self.n_window_infer // (self.n_window * 2)
464
+ )
465
+ # Use tolist() for efficient batch conversion from tensor to Python
466
+ for cnn_len in aftercnn_lens.tolist():
467
+ num_full_chunks = cnn_len // window_aftercnn
468
+ remainder = cnn_len % window_aftercnn
469
+ cu_chunk_lens.extend([window_aftercnn] * num_full_chunks)
470
+ if remainder:
471
+ cu_chunk_lens.append(remainder)
472
+ cu_seqlens = torch.tensor(cu_chunk_lens, device=aftercnn_lens.device).cumsum(
473
+ -1, dtype=torch.int32
474
+ )
475
+
476
+ max_seqlen = self.compute_attn_mask_seqlen(cu_seqlens)
477
+
478
+ # Apply transformer layers
479
+ for encoder_layer in self.layers:
480
+ hidden_states = encoder_layer(
481
+ hidden_states,
482
+ cu_seqlens,
483
+ max_seqlen,
484
+ )
485
+
486
+ # Apply output layers
487
+ hidden_states = self.ln_post(hidden_states)
488
+ hidden_states = self.proj1(hidden_states)
489
+ hidden_states = self.act(hidden_states)
490
+ hidden_states = self.proj2(hidden_states)
491
+
492
+ return hidden_states
493
+
494
+ def _get_cnn_output_lengths(self, input_lengths: torch.Tensor) -> torch.Tensor:
495
+ """Compute output lengths after the three conv2d layers."""
496
+ lengths = input_lengths
497
+ for _ in range(3):
498
+ lengths = (lengths - 1) // 2 + 1
499
+ return lengths
500
+
501
+ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
502
+ """Load weights with mapping from HuggingFace format."""
503
+ stacked_params_mapping = [
504
+ # (param_name, shard_name, shard_id)
505
+ ("self_attn.qkv.", "self_attn.q_proj.", "q"),
506
+ ("self_attn.qkv.", "self_attn.k_proj.", "k"),
507
+ ("self_attn.qkv.", "self_attn.v_proj.", "v"),
508
+ ]
509
+ params_dict = dict(self.named_parameters(remove_duplicate=False))
510
+ loaded_params: set[str] = set()
511
+
512
+ for name, loaded_weight in weights:
513
+ for param_name, weight_name, shard_id in stacked_params_mapping:
514
+ if weight_name not in name:
515
+ continue
516
+ name = name.replace(weight_name, param_name)
517
+
518
+ param = params_dict[name]
519
+ weight_loader = param.weight_loader
520
+ weight_loader(param, loaded_weight, shard_id)
521
+ break
522
+ else:
523
+ param = params_dict.get(name)
524
+ if param is not None:
525
+ weight_loader = getattr(
526
+ param, "weight_loader", default_weight_loader
527
+ )
528
+ weight_loader(param, loaded_weight)
529
+ loaded_params.add(name)
530
+ return loaded_params
531
+
532
+
533
+ class Qwen3ASRProcessingInfo(BaseProcessingInfo):
534
+ def get_hf_config(self):
535
+ return self.ctx.get_hf_config(Qwen3ASRConfig).thinker_config
536
+
537
+ def get_hf_processor(self, **kwargs: object) -> Qwen3ASRProcessor:
538
+ processor = self.ctx.get_hf_processor(
539
+ Qwen3ASRProcessor,
540
+ use_fast=kwargs.pop("use_fast", True),
541
+ **kwargs,
542
+ )
543
+ if not hasattr(processor, "audio_token"):
544
+ processor.audio_token = "<|audio_pad|>"
545
+ return processor
546
+
547
+ def get_feature_extractor(self, **kwargs: object) -> WhisperFeatureExtractor:
548
+ hf_processor = self.get_hf_processor(**kwargs)
549
+ feature_extractor = hf_processor.feature_extractor
550
+ assert isinstance(feature_extractor, WhisperFeatureExtractor)
551
+ return feature_extractor
552
+
553
+ def get_supported_mm_limits(self) -> Mapping[str, int | None]:
554
+ return {"audio": None}
555
+
556
+
557
+ class Qwen3ASRDummyInputsBuilder(BaseDummyInputsBuilder[Qwen3ASRProcessingInfo]):
558
+ def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
559
+ num_audios = mm_counts.get("audio", 0)
560
+
561
+ hf_processor = self.info.get_hf_processor()
562
+ audio_token = hf_processor.audio_token
563
+
564
+ return audio_token * num_audios
565
+
566
+ def get_dummy_mm_data(
567
+ self,
568
+ seq_len: int,
569
+ mm_counts: Mapping[str, int],
570
+ mm_options: Mapping[str, BaseDummyOptions] | None = None,
571
+ ) -> MultiModalDataDict:
572
+ num_audios = mm_counts.get("audio", 0)
573
+
574
+ feature_extractor = self.info.get_feature_extractor()
575
+
576
+ target_audio_length = (
577
+ min(
578
+ feature_extractor.chunk_length,
579
+ 30,
580
+ )
581
+ * feature_extractor.sampling_rate
582
+ )
583
+
584
+ audio_overrides = mm_options.get("audio") if mm_options else None
585
+
586
+ return {
587
+ "audio": self._get_dummy_audios(
588
+ length=target_audio_length,
589
+ num_audios=num_audios,
590
+ overrides=audio_overrides,
591
+ ),
592
+ }
593
+
594
+
595
+ def _qwen3asr_field_config(hf_inputs: Mapping[str, torch.Tensor]):
596
+ audio_feature_lengths = hf_inputs.get("audio_feature_lengths", torch.empty((0,)))
597
+ return dict(
598
+ input_audio_features=MultiModalFieldConfig.flat_from_sizes(
599
+ "audio", audio_feature_lengths, dim=1
600
+ ),
601
+ feature_attention_mask=MultiModalFieldConfig.batched("audio"),
602
+ audio_feature_lengths=MultiModalFieldConfig.batched("audio"),
603
+ )
604
+
605
+
606
+ class Qwen3ASRMultiModalDataParser(MultiModalDataParser):
607
+ def _parse_audio_data(
608
+ self,
609
+ data: dict[str, torch.Tensor] | ModalityData[AudioItem],
610
+ ) -> ModalityDataItems[Any, Any] | None:
611
+ if isinstance(data, dict):
612
+ return DictEmbeddingItems(
613
+ data,
614
+ modality="audio",
615
+ required_fields={"input_audio_features", "audio_feature_lengths"},
616
+ fields_factory=_qwen3asr_field_config,
617
+ )
618
+
619
+ return super()._parse_audio_data(data)
620
+
621
+
622
+ class Qwen3ASRMultiModalProcessor(
623
+ Qwen3OmniMoeThinkerMultiModalProcessor,
624
+ ):
625
+ def _get_data_parser(self) -> MultiModalDataParser:
626
+ feature_extractor = self.info.get_feature_extractor()
627
+ return Qwen3ASRMultiModalDataParser(
628
+ target_sr=feature_extractor.sampling_rate,
629
+ )
630
+
631
+ def _get_mm_fields_config(
632
+ self,
633
+ hf_inputs: BatchFeature,
634
+ hf_processor_mm_kwargs: Mapping[str, object],
635
+ ) -> Mapping[str, MultiModalFieldConfig]:
636
+ return _qwen3asr_field_config(hf_inputs)
637
+
638
+ def _get_prompt_updates(
639
+ self,
640
+ mm_items: MultiModalDataItems,
641
+ hf_processor_mm_kwargs: Mapping[str, Any],
642
+ out_mm_kwargs: MultiModalKwargsItems,
643
+ ) -> Sequence[PromptUpdate]:
644
+ processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
645
+ tokenizer = self.info.get_tokenizer()
646
+ vocab = tokenizer.get_vocab()
647
+
648
+ audio_token = processor.audio_token
649
+ audio_token_id = vocab[audio_token]
650
+
651
+ out_mm_data = out_mm_kwargs.get_data()
652
+ audio_feature_lengths = out_mm_data.get("audio_feature_lengths")
653
+ feature_attention_mask = out_mm_data.get("feature_attention_mask")
654
+ if audio_feature_lengths is None and feature_attention_mask is None:
655
+ audio_output_lengths = []
656
+ elif audio_feature_lengths is not None:
657
+ audio_output_lens = _get_feat_extract_output_lengths(audio_feature_lengths)
658
+ audio_output_lengths = audio_output_lens.tolist()
659
+ elif feature_attention_mask is not None:
660
+ assert isinstance(feature_attention_mask, torch.Tensor)
661
+ audio_output_lens = _get_feat_extract_output_lengths(
662
+ feature_attention_mask.sum(-1)
663
+ )
664
+ audio_output_lengths = audio_output_lens.tolist()
665
+
666
+ def get_replacement_qwen2_audio(item_idx: int):
667
+ num_features = audio_output_lengths[item_idx]
668
+ if num_features == 0:
669
+ audios = mm_items.get_items("audio", AudioProcessorItems)
670
+ audio = audios.get(item_idx)
671
+ raise ValueError(
672
+ f"The audio {audio} (len={len(audio)}) is too short "
673
+ "to be represented inside the model"
674
+ )
675
+
676
+ return [audio_token_id] * num_features
677
+
678
+ return [
679
+ PromptReplacement(
680
+ modality="audio",
681
+ target=audio_token,
682
+ replacement=get_replacement_qwen2_audio,
683
+ ),
684
+ ]
685
+
686
+
687
+ @MULTIMODAL_REGISTRY.register_processor(
688
+ Qwen3ASRMultiModalProcessor,
689
+ info=Qwen3ASRProcessingInfo,
690
+ dummy_inputs=Qwen3ASRDummyInputsBuilder,
691
+ )
692
+ class Qwen3ASRForConditionalGeneration(
693
+ nn.Module,
694
+ SupportsMultiModal,
695
+ SupportsPP,
696
+ SupportsMRoPE,
697
+ SupportsTranscription,
698
+ ):
699
+ supported_languages = ISO639_1_SUPPORTED_LANGS
700
+
701
+ hf_to_vllm_mapper = WeightsMapper(
702
+ orig_to_new_prefix={
703
+ "thinker.lm_head.": "language_model.lm_head.",
704
+ "thinker.model.": "language_model.model.",
705
+ "thinker.": "",
706
+ }
707
+ )
708
+
709
+ @classmethod
710
+ def get_placeholder_str(cls, modality: str, i: int) -> str | None:
711
+ if modality.startswith("audio"):
712
+ return "<|audio_start|><|audio_pad|><|audio_end|>"
713
+
714
+ raise ValueError("Only audio modality is supported")
715
+
716
+ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
717
+ super().__init__()
718
+ self.vllm_config = vllm_config # needed for torch compile forward context
719
+ thinker_config: Qwen3ASRThinkerConfig = (
720
+ vllm_config.model_config.hf_config.thinker_config
721
+ )
722
+ quant_config = vllm_config.quant_config
723
+ multimodal_config = vllm_config.model_config.multimodal_config
724
+ self.config = thinker_config
725
+ self.multimodal_config = multimodal_config
726
+
727
+ self.audio_tower = Qwen3ASRAudioEncoder(
728
+ thinker_config.audio_config,
729
+ multimodal_config=multimodal_config,
730
+ prefix=maybe_prefix(prefix, "audio_tower"),
731
+ )
732
+ self.quant_config = quant_config
733
+
734
+ self.language_model = Qwen3ForCausalLM(
735
+ vllm_config=vllm_config.with_hf_config(
736
+ thinker_config.text_config, architectures=["Qwen3ForCausalLM"]
737
+ ),
738
+ prefix=maybe_prefix(prefix, "language_model"),
739
+ )
740
+
741
+ self.make_empty_intermediate_tensors = (
742
+ self.language_model.make_empty_intermediate_tensors
743
+ )
744
+
745
+ def _parse_and_validate_audio_input(
746
+ self, **kwargs: object
747
+ ) -> Qwen2_5OmniAudioFeatureInputs | None:
748
+ input_audio_features = kwargs.pop("input_audio_features", None)
749
+ audio_feature_lengths = kwargs.pop("audio_feature_lengths", None)
750
+ feature_attention_mask = kwargs.pop("feature_attention_mask", None)
751
+ if input_audio_features is None:
752
+ return None
753
+
754
+ return Qwen2_5OmniAudioFeatureInputs(
755
+ type="audio_features",
756
+ input_features=input_audio_features,
757
+ audio_feature_lengths=audio_feature_lengths,
758
+ feature_attention_mask=feature_attention_mask,
759
+ )
760
+
761
+ def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict:
762
+ mm_input_by_modality = {}
763
+
764
+ # Preserve the order of modalities if there are multiple of them
765
+ # from the order of kwargs.
766
+ for input_key in kwargs:
767
+ if (
768
+ input_key in ("input_audio_features")
769
+ and "audio" not in mm_input_by_modality
770
+ ):
771
+ mm_input_by_modality["audio"] = self._parse_and_validate_audio_input(
772
+ **kwargs
773
+ )
774
+ return mm_input_by_modality
775
+
776
+ def _process_audio_input(
777
+ self,
778
+ audio_input: Qwen2_5OmniAudioFeatureInputs,
779
+ audio_hashes: list[str] | None = None,
780
+ cached_audio_features: torch.Tensor | None = None,
781
+ ) -> torch.Tensor:
782
+ input_features = audio_input["input_features"]
783
+ audio_feature_lengths = audio_input["audio_feature_lengths"]
784
+
785
+ audio_output_lengths = _get_feat_extract_output_lengths(audio_feature_lengths)
786
+
787
+ audio_features = self.audio_tower(
788
+ input_features.to(self.audio_tower.dtype),
789
+ feature_lens=audio_feature_lengths,
790
+ aftercnn_lens=audio_output_lengths,
791
+ )
792
+ return audio_features.split(audio_output_lengths.tolist())
793
+
794
+ def get_language_model(self) -> torch.nn.Module:
795
+ return self.language_model
796
+
797
+ def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings | None:
798
+ mm_input_by_modality = self._parse_and_validate_multimodal_inputs(**kwargs)
799
+ if not mm_input_by_modality:
800
+ return []
801
+
802
+ # The result multimodal_embeddings is tuple of tensors, with each
803
+ # tensor correspoending to a multimodal data item (image or video).
804
+ multimodal_embeddings: tuple[torch.Tensor, ...] = ()
805
+
806
+ # NOTE: It is important to iterate over the keys in this dictionary
807
+ # to preserve the order of the modalities.
808
+ for modality in mm_input_by_modality:
809
+ multimodal_input = mm_input_by_modality[modality]
810
+ if modality == "audio":
811
+ audio_embeddings = self._process_audio_input(multimodal_input)
812
+ multimodal_embeddings += tuple(audio_embeddings)
813
+ return multimodal_embeddings
814
+
815
+ def embed_input_ids(
816
+ self,
817
+ input_ids: torch.Tensor,
818
+ multimodal_embeddings: MultiModalEmbeddings | None = None,
819
+ *,
820
+ is_multimodal: torch.Tensor | None = None,
821
+ handle_oov_mm_token: bool = False,
822
+ ) -> torch.Tensor:
823
+ inputs_embeds = self._embed_text_input_ids(
824
+ input_ids,
825
+ self.language_model.embed_input_ids,
826
+ is_multimodal=is_multimodal,
827
+ handle_oov_mm_token=handle_oov_mm_token,
828
+ )
829
+
830
+ if multimodal_embeddings is None or len(multimodal_embeddings) == 0:
831
+ return inputs_embeds
832
+
833
+ inputs_embeds = _merge_multimodal_embeddings(
834
+ inputs_embeds=inputs_embeds,
835
+ multimodal_embeddings=multimodal_embeddings,
836
+ is_multimodal=is_multimodal,
837
+ )
838
+
839
+ return inputs_embeds
840
+
841
+ def forward(
842
+ self,
843
+ input_ids: torch.Tensor,
844
+ positions: torch.Tensor,
845
+ intermediate_tensors: IntermediateTensors | None = None,
846
+ inputs_embeds: torch.Tensor | None = None,
847
+ **kwargs: object,
848
+ ) -> torch.Tensor | IntermediateTensors:
849
+ if intermediate_tensors is not None:
850
+ inputs_embeds = None
851
+
852
+ hidden_states = self.language_model.model(
853
+ input_ids,
854
+ positions,
855
+ intermediate_tensors,
856
+ inputs_embeds=inputs_embeds,
857
+ )
858
+
859
+ return hidden_states
860
+
861
+ def compute_logits(
862
+ self,
863
+ hidden_states: torch.Tensor,
864
+ ) -> torch.Tensor | None:
865
+ return self.language_model.compute_logits(hidden_states)
866
+
867
+ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
868
+ loader = AutoWeightsLoader(
869
+ self,
870
+ skip_prefixes=["talker.", "code2wav."],
871
+ )
872
+ loaded_weights = loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
873
+
874
+ return loaded_weights
875
+
876
+ def get_mrope_input_positions(
877
+ self,
878
+ input_tokens: list[int],
879
+ mm_features: list[MultiModalFeatureSpec],
880
+ ) -> tuple[torch.Tensor, int]:
881
+ seq_len = len(input_tokens)
882
+
883
+ if not mm_features:
884
+ # No audio features, just return linear positions
885
+ llm_positions = (
886
+ torch.arange(seq_len, dtype=torch.long).view(1, -1).expand(3, -1)
887
+ )
888
+ return llm_positions.clone(), 0
889
+
890
+ llm_pos_ids_list: list[torch.Tensor] = []
891
+ st = 0
892
+
893
+ for mm_feature in sorted(mm_features, key=lambda f: f.mm_position.offset):
894
+ offset = mm_feature.mm_position.offset
895
+
896
+ # Get audio feature length from mm_feature data
897
+ audio_feature_length = mm_feature.data["audio_feature_lengths"].data
898
+ if isinstance(audio_feature_length, torch.Tensor):
899
+ audio_feature_length = audio_feature_length.item()
900
+ audio_len = _get_feat_extract_output_lengths(
901
+ torch.tensor(audio_feature_length)
902
+ ).item()
903
+
904
+ # Text segment before audio (includes audio_start token)
905
+ text_len = offset - st
906
+ st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0
907
+ text_positions = (
908
+ torch.arange(text_len, dtype=torch.long).view(1, -1).expand(3, -1)
909
+ + st_idx
910
+ )
911
+ llm_pos_ids_list.append(text_positions)
912
+ st_idx = st_idx + text_len
913
+
914
+ # Audio token segment
915
+ audio_positions = (
916
+ torch.arange(audio_len, dtype=torch.long).view(1, -1).expand(3, -1)
917
+ + st_idx
918
+ )
919
+ llm_pos_ids_list.append(audio_positions)
920
+
921
+ st = offset + audio_len
922
+
923
+ # Handle remaining text (includes audio_end and any trailing text)
924
+ if st < seq_len:
925
+ st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0
926
+ text_len = seq_len - st
927
+ final_text_positions = (
928
+ torch.arange(text_len, dtype=torch.long).view(1, -1).expand(3, -1)
929
+ + st_idx
930
+ )
931
+ llm_pos_ids_list.append(final_text_positions)
932
+
933
+ llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)
934
+ if llm_positions.shape[1] != seq_len:
935
+ raise RuntimeError("Position ids length mismatch with input ids length")
936
+
937
+ mrope_position_delta = (llm_positions.max() + 1 - seq_len).item()
938
+ return llm_positions, mrope_position_delta
939
+
940
+ def get_mm_mapping(self) -> MultiModelKeys:
941
+ """
942
+ Get the module prefix in multimodal models
943
+ """
944
+ return MultiModelKeys.from_string_field(
945
+ language_model="language_model",
946
+ tower_model=["audio_tower."],
947
+ )
948
+
949
+ @classmethod
950
+ def get_speech_to_text_config(
951
+ cls, model_config: ModelConfig, task_type: str
952
+ ) -> SpeechToTextConfig:
953
+ processor = cached_processor_from_config(model_config)
954
+ feature_extractor: WhisperFeatureExtractor = processor.feature_extractor
955
+ return SpeechToTextConfig(
956
+ max_audio_clip_s=feature_extractor.chunk_length,
957
+ sample_rate=feature_extractor.sampling_rate,
958
+ )
959
+
960
+ @classmethod
961
+ def get_generation_prompt(
962
+ cls,
963
+ audio: np.ndarray,
964
+ model_config: ModelConfig,
965
+ stt_config: SpeechToTextConfig,
966
+ language: str | None,
967
+ task_type: Literal["transcribe", "translate"],
968
+ request_prompt: str,
969
+ to_language: str | None,
970
+ ) -> PromptType:
971
+ """Get the generation prompt to be used for transcription requests."""
972
+ tokenizer = cached_tokenizer_from_config(model_config)
973
+ audio_placeholder = cls.get_placeholder_str("audio", 0)
974
+
975
+ if task_type not in ("transcribe", "translate"):
976
+ raise ValueError(
977
+ f"Unsupported task_type '{task_type}'. "
978
+ "Supported task types are 'transcribe' and 'translate'."
979
+ )
980
+ full_lang_name_to = cls.supported_languages.get(to_language, to_language)
981
+ if to_language is None:
982
+ prompt = (
983
+ f"<|im_start|>user\n{audio_placeholder}<|im_end|>\n"
984
+ f"<|im_start|>assistant\n"
985
+ )
986
+ else:
987
+ prompt = (
988
+ f"<|im_start|>user\n{audio_placeholder}<|im_end|>\n"
989
+ f"<|im_start|>assistant\nlanguage {full_lang_name_to}<asr_text>"
990
+ )
991
+
992
+ prompt_token_ids = tokenizer.encode(prompt)
993
+ prompt_dict = {
994
+ "prompt_token_ids": prompt_token_ids,
995
+ "multi_modal_data": {"audio": audio},
996
+ }
997
+ return cast(PromptType, prompt_dict)
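
Taken together with the processor and prompt helpers above, transcription through the vLLM backend reduces to an `LLM.generate` call whose prompt contains the audio placeholder. A minimal offline sketch, assuming vLLM and this package are installed, the checkpoint id `Qwen/Qwen3-ASR-1.7B` is accessible, and a silent 16 kHz waveform standing in for real speech:

```python
import numpy as np
import qwen_asr.inference.qwen3_asr  # noqa: F401 -- registers the model class with vLLM's ModelRegistry
from vllm import LLM, SamplingParams

llm = LLM(model="Qwen/Qwen3-ASR-1.7B")

# 1 s of silence at 16 kHz (the sample rate the feature extractor expects).
wav = np.zeros(16000, dtype=np.float32)

# Same prompt layout as get_generation_prompt with no forced target language.
prompt = (
    "<|im_start|>user\n<|audio_start|><|audio_pad|><|audio_end|><|im_end|>\n"
    "<|im_start|>assistant\n"
)
outputs = llm.generate(
    [{"prompt": prompt, "multi_modal_data": {"audio": [wav]}}],
    SamplingParams(temperature=0.0, max_tokens=512),
)
print(outputs[0].outputs[0].text)
```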
qwen_asr/inference/assets/korean_dict_jieba.dict ADDED
The diff for this file is too large to render. See raw diff
 
qwen_asr/inference/qwen3_asr.py ADDED
@@ -0,0 +1,519 @@
1
+ # coding=utf-8
2
+ # Copyright 2026 The Alibaba Qwen team.
3
+ # SPDX-License-Identifier: Apache-2.0
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ from dataclasses import dataclass
17
+ from typing import Any, Dict, List, Optional, Union
18
+
19
+ import numpy as np
20
+ import torch
21
+ from qwen_asr.core.transformers_backend import (
22
+ Qwen3ASRConfig,
23
+ Qwen3ASRForConditionalGeneration,
24
+ Qwen3ASRProcessor,
25
+ )
26
+ from transformers import AutoConfig, AutoModel, AutoProcessor
27
+
28
+ AutoConfig.register("qwen3_asr", Qwen3ASRConfig)
29
+ AutoModel.register(Qwen3ASRConfig, Qwen3ASRForConditionalGeneration)
30
+ AutoProcessor.register(Qwen3ASRConfig, Qwen3ASRProcessor)
31
+
32
+ from .qwen3_forced_aligner import Qwen3ForcedAligner
33
+ from .utils import (
34
+ MAX_ASR_INPUT_SECONDS,
35
+ MAX_FORCE_ALIGN_INPUT_SECONDS,
36
+ SAMPLE_RATE,
37
+ SUPPORTED_LANGUAGES,
38
+ AudioChunk,
39
+ AudioLike,
40
+ chunk_list,
41
+ merge_languages,
42
+ normalize_audios,
43
+ normalize_language_name,
44
+ parse_asr_output,
45
+ split_audio_into_chunks,
46
+ validate_language,
47
+ )
48
+
49
+ try:
50
+ from qwen_asr.core.vllm_backend import Qwen3ASRForConditionalGeneration
51
+ from vllm import ModelRegistry
52
+ ModelRegistry.register_model("Qwen3ASRForConditionalGeneration", Qwen3ASRForConditionalGeneration)
53
+ except Exception:  # the vLLM backend is optional
54
+ pass
55
+
56
+
57
+ @dataclass
58
+ class ASRTranscription:
59
+ """
60
+ One transcription result.
61
+
62
+ Attributes:
63
+ language (str):
64
+ Merged language string for the sample, e.g. "Chinese" or "Chinese,English".
65
+ Empty string if unknown or silent audio.
66
+ text (str):
67
+ Transcribed text.
68
+ time_stamps (Optional[Any]):
69
+ Forced aligner output (ForcedAlignResult).
70
+ Present only when return_time_stamps=True.
71
+ """
72
+ language: str
73
+ text: str
74
+ time_stamps: Optional[Any] = None
75
+
76
+
77
+ class Qwen3ASRModel:
78
+ """
79
+ Unified inference wrapper for Qwen3-ASR with two backends:
80
+ - Transformers backend
81
+ - vLLM backend
82
+
83
+ It optionally supports time stamp output via Qwen3-ForcedAligner.
84
+
85
+ Notes:
86
+ - Each request uses a context text and exactly one audio.
87
+ - If language is provided, the prompt will force the output to be text-only by appending
88
+ "language {Language}<asr_text>" to the assistant prompt.
89
+ """
90
+
91
+ def __init__(
92
+ self,
93
+ backend: str,
94
+ model: Any,
95
+ processor: Any,
96
+ sampling_params: Optional[Any] = None,
97
+ forced_aligner: Optional[Qwen3ForcedAligner] = None,
98
+ max_inference_batch_size: int = -1,
99
+ ):
100
+ self.backend = backend # "transformers" | "vllm"
101
+ self.model = model
102
+ self.processor = processor
103
+ self.sampling_params = sampling_params
104
+ self.forced_aligner = forced_aligner
105
+ self.max_inference_batch_size = int(max_inference_batch_size)
106
+
107
+ if backend == "transformers":
108
+ self.device = getattr(model, "device", None)
109
+ if self.device is None:
110
+ try:
111
+ self.device = next(model.parameters()).device
112
+ except StopIteration:
113
+ self.device = torch.device("cpu")
114
+ self.dtype = getattr(model, "dtype", torch.float32)
115
+ else:
116
+ self.device = None
117
+ self.dtype = None
118
+
119
+ @classmethod
120
+ def from_pretrained(
121
+ cls,
122
+ pretrained_model_name_or_path: str,
123
+ forced_aligner: Optional[str] = None,
124
+ forced_aligner_kwargs: Optional[Dict[str, Any]] = None,
125
+ max_inference_batch_size: int = -1,
126
+ **kwargs,
127
+ ) -> "Qwen3ASRModel":
128
+ """
129
+ Initialize using Transformers backend.
130
+
131
+ Args:
132
+ pretrained_model_name_or_path:
133
+ HuggingFace repo id or local directory.
134
+ forced_aligner:
135
+ Optional forced aligner model path/repo id.
136
+ forced_aligner_kwargs:
137
+ Optional kwargs forwarded to Qwen3ForcedAligner.from_pretrained(...).
138
+ max_inference_batch_size:
139
+ Batch size limit for inference. -1 means no chunking. Small values can avoid OOM.
140
+ **kwargs:
141
+ Forwarded to AutoModel.from_pretrained(...).
142
+
143
+ Returns:
144
+ Qwen3ASRModel
145
+ """
146
+
147
+ model = AutoModel.from_pretrained(pretrained_model_name_or_path, **kwargs)
148
+
149
+ processor = AutoProcessor.from_pretrained(pretrained_model_name_or_path, fix_mistral_regex=True)
150
+
151
+ forced_aligner_model = None
+ if forced_aligner is not None:
+ forced_aligner_model = Qwen3ForcedAligner.from_pretrained(
+ forced_aligner, **(forced_aligner_kwargs or {})
+ )
155
+
156
+ return cls(
157
+ backend="transformers",
158
+ model=model,
159
+ processor=processor,
160
+ sampling_params=None,
161
+ forced_aligner=forced_aligner_model,
162
+ max_inference_batch_size=max_inference_batch_size,
163
+ )
164
+
165
+ @classmethod
166
+ def LLM(
167
+ cls,
168
+ model: str,
169
+ forced_aligner: Optional[str] = None,
170
+ forced_aligner_kwargs: Optional[Dict[str, Any]] = None,
171
+ max_inference_batch_size: int = -1,
172
+ max_new_tokens: Optional[int] = 8192,
173
+ **kwargs,
174
+ ) -> "Qwen3ASRModel":
175
+ """
176
+ Initialize using vLLM backend.
177
+
178
+ Import is isolated to keep vLLM optional.
179
+
180
+ Args:
181
+ model:
182
+ Model path/repo for vLLM.
183
+ forced_aligner:
184
+ Optional forced aligner model path/repo id.
185
+ forced_aligner_kwargs:
186
+ Optional kwargs forwarded to Qwen3ForcedAligner.from_pretrained(...).
187
+ max_inference_batch_size:
188
+ Batch size limit for inference. -1 means no chunking. Small values can avoid OOM.
189
+ max_new_tokens:
190
+ Maximum number of tokens to generate.
191
+ **kwargs:
192
+ Forwarded to vllm.LLM(...).
193
+
194
+ Returns:
195
+ Qwen3ASRModel
196
+
197
+ Raises:
198
+ ImportError: If vLLM is not installed.
199
+ """
200
+ try:
201
+ from vllm import LLM as vLLM
202
+ from vllm import SamplingParams
203
+ except Exception as e:
204
+ raise ImportError(
205
+ "vLLM is not available. Install with: pip install qwen-asr[vllm]"
206
+ ) from e
207
+
208
+ llm = vLLM(model=model, **kwargs)
209
+
210
+ processor = Qwen3ASRProcessor.from_pretrained(model, fix_mistral_regex=True)
211
+ sampling_params = SamplingParams(temperature=0.0, max_tokens=max_new_tokens)
212
+
213
+ forced_aligner_model = None
+ if forced_aligner is not None:
+ forced_aligner_model = Qwen3ForcedAligner.from_pretrained(
+ forced_aligner, **(forced_aligner_kwargs or {})
+ )
217
+
218
+ return cls(
219
+ backend="vllm",
220
+ model=llm,
221
+ processor=processor,
222
+ sampling_params=sampling_params,
223
+ forced_aligner=forced_aligner_model,
224
+ max_inference_batch_size=max_inference_batch_size,
225
+ )
226
+
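
As a usage note for the constructor above, a vLLM-backed wrapper with timestamp support could be created roughly as follows; the model ids are the ones this Space uses, and `gpu_memory_utilization` is just an illustrative kwarg forwarded to `vllm.LLM(...)`:

```python
from qwen_asr.inference.qwen3_asr import Qwen3ASRModel

asr = Qwen3ASRModel.LLM(
    model="Qwen/Qwen3-ASR-1.7B",
    forced_aligner="Qwen/Qwen3-ForcedAligner-0.6B",
    max_inference_batch_size=8,
    gpu_memory_utilization=0.8,  # forwarded to vllm.LLM(...)
)
```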
227
+ def get_supported_languages(self) -> List[str]:
228
+ """
229
+ Returns the supported language list.
230
+
231
+ Returns:
232
+ List[str]: Canonical language names.
233
+ """
234
+ return list(SUPPORTED_LANGUAGES)
235
+
236
+ @torch.no_grad()
237
+ def transcribe(
238
+ self,
239
+ audio: Union[AudioLike, List[AudioLike]],
240
+ context: Union[str, List[str]] = "",
241
+ language: Optional[Union[str, List[Optional[str]]]] = None,
242
+ return_time_stamps: bool = False,
243
+ ) -> List[ASRTranscription]:
244
+ """
245
+ Transcribe audio with optional context and optional forced alignment timestamps.
246
+
247
+ Args:
248
+ audio:
249
+ Audio input(s). Supported:
250
+ - str: local path / URL / base64 data url
251
+ - (np.ndarray, sr)
252
+ - list of above
253
+ context:
254
+ Context string(s). If scalar, it will be broadcast to batch size.
255
+ language:
256
+ Optional language(s). If provided, it must be in supported languages.
257
+ If scalar, it will be broadcast to batch size.
258
+ If provided, the prompt will force output to be transcription text only.
259
+ return_time_stamps:
260
+ If True, timestamps are produced via forced aligner and merged across chunks.
261
+ This requires forced_aligner initialized.
262
+
263
+ Returns:
264
+ List[ASRTranscription]: One result per input audio.
265
+
266
+ Raises:
267
+ ValueError:
268
+ - If return_time_stamps=True but forced_aligner is not provided.
269
+ - If language is unsupported.
270
+ - If batch sizes mismatch for context/language.
271
+ """
272
+ if return_time_stamps and self.forced_aligner is None:
273
+ raise ValueError("return_time_stamps=True requires `forced_aligner` to be provided at initialization.")
274
+
275
+ wavs = normalize_audios(audio)
276
+ n = len(wavs)
277
+
278
+ ctxs = context if isinstance(context, list) else [context]
279
+ if len(ctxs) == 1 and n > 1:
280
+ ctxs = ctxs * n
281
+ if len(ctxs) != n:
282
+ raise ValueError(f"Batch size mismatch: audio={n}, context={len(ctxs)}")
283
+
284
+ langs_in: List[Optional[str]]
285
+ if language is None:
286
+ langs_in = [None] * n
287
+ else:
288
+ langs_in = language if isinstance(language, list) else [language]
289
+ if len(langs_in) == 1 and n > 1:
290
+ langs_in = langs_in * n
291
+ if len(langs_in) != n:
292
+ raise ValueError(f"Batch size mismatch: audio={n}, language={len(langs_in)}")
293
+
294
+ langs_norm: List[Optional[str]] = []
295
+ for l in langs_in:
296
+ if l is None or str(l).strip() == "":
297
+ langs_norm.append(None)
298
+ else:
299
+ ln = normalize_language_name(str(l))
300
+ validate_language(ln)
301
+ langs_norm.append(ln)
302
+
303
+ max_chunk_sec = MAX_FORCE_ALIGN_INPUT_SECONDS if return_time_stamps else MAX_ASR_INPUT_SECONDS
304
+
305
+ # chunk audios and record mapping
306
+ chunks: List[AudioChunk] = []
307
+ for i, wav in enumerate(wavs):
308
+ parts = split_audio_into_chunks(
309
+ wav=wav,
310
+ sr=SAMPLE_RATE,
311
+ max_chunk_sec=max_chunk_sec,
312
+ )
313
+ for j, (cwav, offset_sec) in enumerate(parts):
314
+ chunks.append(AudioChunk(orig_index=i, chunk_index=j, wav=cwav, sr=SAMPLE_RATE, offset_sec=offset_sec))
315
+
316
+ # run ASR on chunks
317
+ chunk_ctx: List[str] = [ctxs[c.orig_index] for c in chunks]
318
+ chunk_lang: List[Optional[str]] = [langs_norm[c.orig_index] for c in chunks]
319
+ chunk_wavs: List[np.ndarray] = [c.wav for c in chunks]
320
+ raw_outputs = self._infer_asr(chunk_ctx, chunk_wavs, chunk_lang)
321
+
322
+ # parse outputs, prepare for optional alignment
323
+ per_chunk_lang: List[str] = []
324
+ per_chunk_text: List[str] = []
325
+ for out, forced_lang in zip(raw_outputs, chunk_lang):
326
+ lang, txt = parse_asr_output(out, user_language=forced_lang)
327
+ per_chunk_lang.append(lang)
328
+ per_chunk_text.append(txt)
329
+
330
+ # forced alignment (optional)
331
+ per_chunk_align: List[Optional[Any]] = [None] * len(chunks)
332
+ if return_time_stamps:
333
+ to_align_audio = []
334
+ to_align_text = []
335
+ to_align_lang = []
336
+ to_align_idx = []
337
+
338
+ for idx, (c, txt, lang_pred) in enumerate(zip(chunks, per_chunk_text, per_chunk_lang)):
339
+ if txt.strip() == "":
340
+ continue
341
+ to_align_audio.append((c.wav, c.sr))
342
+ to_align_text.append(txt)
343
+ to_align_lang.append(lang_pred)
344
+ to_align_idx.append(idx)
345
+
346
+ # batch align with max_inference_batch_size
347
+ aligned_results: List[Any] = []
348
+ for a_chunk, t_chunk, l_chunk in zip(
349
+ chunk_list(to_align_audio, self.max_inference_batch_size),
350
+ chunk_list(to_align_text, self.max_inference_batch_size),
351
+ chunk_list(to_align_lang, self.max_inference_batch_size),
352
+ ):
353
+ aligned_results.extend(
354
+ self.forced_aligner.align(audio=a_chunk, text=t_chunk, language=l_chunk)
355
+ )
356
+
357
+ # offset fix
358
+ for k, idx in enumerate(to_align_idx):
359
+ c = chunks[idx]
360
+ r = aligned_results[k]
361
+ per_chunk_align[idx] = self._offset_align_result(r, c.offset_sec)
362
+
363
+ # merge chunks back to original samples
364
+ out_langs: List[List[str]] = [[] for _ in range(n)]
365
+ out_texts: List[List[str]] = [[] for _ in range(n)]
366
+ out_aligns: List[List[Any]] = [[] for _ in range(n)]
367
+
368
+ for c, lang, txt, al in zip(chunks, per_chunk_lang, per_chunk_text, per_chunk_align):
369
+ out_langs[c.orig_index].append(lang)
370
+ out_texts[c.orig_index].append(txt)
371
+ if return_time_stamps and al is not None:
372
+ out_aligns[c.orig_index].append(al)
373
+
374
+ results: List[ASRTranscription] = []
375
+ for i in range(n):
376
+ merged_text = "".join([t for t in out_texts[i] if t is not None])
377
+ merged_language = merge_languages(out_langs[i])
378
+ merged_align = None
379
+ if return_time_stamps:
380
+ merged_align = self._merge_align_results(out_aligns[i])
381
+ results.append(ASRTranscription(language=merged_language, text=merged_text, time_stamps=merged_align))
382
+
383
+ return results
384
+
385
+ def _build_messages(self, context: str, audio_payload: Any) -> List[Dict[str, Any]]:
386
+ return [
387
+ {"role": "system", "content": context or ""},
388
+ {"role": "user", "content": [{"type": "audio", "audio": audio_payload}]},
389
+ ]
390
+
391
+ def _build_text_prompt(self, context: str, force_language: Optional[str]) -> str:
392
+ """
393
+ Build the string prompt for one request.
394
+
395
+ If force_language is provided, "language X<asr_text>" is appended after the generation prompt
396
+ to request text-only output.
397
+ """
398
+ msgs = self._build_messages(context=context, audio_payload="")
399
+ base = self.processor.apply_chat_template(msgs, add_generation_prompt=True, tokenize=False)
400
+ if force_language:
401
+ base = base + f"language {force_language}{'<asr_text>'}"
402
+ return base
403
+
404
+ def _infer_asr(
405
+ self,
406
+ contexts: List[str],
407
+ wavs: List[np.ndarray],
408
+ languages: List[Optional[str]],
409
+ ) -> List[str]:
410
+ """
411
+ Run backend inference for chunk-level items.
412
+
413
+ Args:
414
+ contexts: List of system context strings.
415
+ wavs: List of mono waveforms (np.ndarray).
416
+ languages: List of forced languages or None.
417
+
418
+ Returns:
419
+ List[str]: Raw decoded strings (one per chunk).
420
+ """
421
+ if self.backend == "transformers":
422
+ return self._infer_asr_transformers(contexts, wavs, languages)
423
+ if self.backend == "vllm":
424
+ return self._infer_asr_vllm(contexts, wavs, languages)
425
+ raise RuntimeError(f"Unknown backend: {self.backend}")
426
+
427
+ def _infer_asr_transformers(
428
+ self,
429
+ contexts: List[str],
430
+ wavs: List[np.ndarray],
431
+ languages: List[Optional[str]],
432
+ ) -> List[str]:
433
+ outs: List[str] = []
434
+
435
+ texts = [self._build_text_prompt(context=c, force_language=fl) for c, fl in zip(contexts, languages)]
436
+
437
+ batch_size = self.max_inference_batch_size
438
+ if batch_size is None or batch_size < 0:
439
+ batch_size = len(texts)
440
+
441
+ for i in range(0, len(texts), batch_size):
442
+ sub_text = texts[i : i + batch_size]
443
+ sub_wavs = wavs[i : i + batch_size]
444
+ inputs = self.processor(text=sub_text, audio=sub_wavs, return_tensors="pt", padding=True)
445
+ inputs = inputs.to(self.model.device).to(self.model.dtype)
446
+
447
+ text_ids = self.model.generate(**inputs)
448
+
449
+ decoded = self.processor.batch_decode(
450
+ text_ids.sequences[:, inputs["input_ids"].shape[1]:],
451
+ skip_special_tokens=True,
452
+ clean_up_tokenization_spaces=False,
453
+ )
454
+ outs.extend(list(decoded))
455
+
456
+ return outs
457
+
458
+ def _infer_asr_vllm(
459
+ self,
460
+ contexts: List[str],
461
+ wavs: List[np.ndarray],
462
+ languages: List[Optional[str]],
463
+ ) -> List[str]:
464
+ inputs: List[Dict[str, Any]] = []
465
+ for c, w, fl in zip(contexts, wavs, languages):
466
+ prompt = self._build_text_prompt(context=c, force_language=fl)
467
+ inputs.append({"prompt": prompt, "multi_modal_data": {"audio": [w]}})
468
+
469
+ outs: List[str] = []
470
+ for batch in chunk_list(inputs, self.max_inference_batch_size):
471
+ outputs = self.model.generate(batch, sampling_params=self.sampling_params, use_tqdm=False)
472
+ for o in outputs:
473
+ outs.append(o.outputs[0].text)
474
+ return outs
475
+
476
+ def _offset_align_result(self, result: Any, offset_sec: float) -> Any:
477
+ """
478
+ Apply time offset to a ForcedAlignResult-like object.
479
+
480
+ This function assumes:
481
+ - result has attribute `.items` which is a list of items with start_time/end_time in seconds.
482
+ - dataclasses are frozen in upstream implementation, so we reconstruct by type.
483
+
484
+ Args:
485
+ result: ForcedAlignResult
486
+ offset_sec: Offset in seconds
487
+
488
+ Returns:
489
+ ForcedAlignResult: New object with shifted timestamps.
490
+ """
491
+ if result is None:
492
+ return None
493
+ items = []
494
+ for it in result.items:
495
+ items.append(type(it)(text=it.text,
496
+ start_time=round(it.start_time + offset_sec, 3),
497
+ end_time=round(it.end_time + offset_sec, 3)))
498
+ return type(result)(items=items)
499
+
500
+ def _merge_align_results(self, results: List[Any]) -> Optional[Any]:
501
+ """
502
+ Merge multiple ForcedAlignResult objects into a single one by concatenating items.
503
+
504
+ Args:
505
+ results: List of ForcedAlignResult
506
+
507
+ Returns:
508
+ ForcedAlignResult or None
509
+ """
510
+ if not results:
511
+ return None
512
+ all_items = []
513
+ for r in results:
514
+ if r is None:
515
+ continue
516
+ all_items.extend(list(r.items))
517
+ if not all_items:
518
+ return None
519
+ return type(results[0])(items=all_items)
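
To close the loop on the wrapper above, here is a minimal end-to-end sketch using the Transformers backend. The model ids are the ones this Space uses; `torch_dtype` and `device_map` are ordinary `from_pretrained` kwargs forwarded by the wrapper, and `speech.wav` is a placeholder path:

```python
import torch
from qwen_asr.inference.qwen3_asr import Qwen3ASRModel

asr = Qwen3ASRModel.from_pretrained(
    "Qwen/Qwen3-ASR-1.7B",
    forced_aligner="Qwen/Qwen3-ForcedAligner-0.6B",
    torch_dtype=torch.bfloat16,  # forwarded to AutoModel.from_pretrained
    device_map="cuda",
)

results = asr.transcribe(
    audio="speech.wav",       # local path / URL / (np.ndarray, sr) are accepted
    language=None,            # None = auto-detect; e.g. "English" forces text-only output
    return_time_stamps=True,  # requires the forced_aligner above
)

for r in results:
    print(r.language, r.text)
    if r.time_stamps is not None:
        for item in r.time_stamps.items:
            print(f"{item.text}: {item.start_time:.2f}s - {item.end_time:.2f}s")
```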
qwen_asr/inference/qwen3_forced_aligner.py ADDED
@@ -0,0 +1,484 @@
1
+ # coding=utf-8
2
+ # Copyright 2026 The Alibaba Qwen team.
3
+ # SPDX-License-Identifier: Apache-2.0
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ import os
17
+ import unicodedata
18
+ from dataclasses import dataclass
19
+ from typing import Any, Dict, List, Optional, Tuple, Union
20
+
21
+ import nagisa
22
+ import numpy as np
23
+ import torch
24
+ from qwen_asr.core.transformers_backend import (
25
+ Qwen3ASRConfig,
26
+ Qwen3ASRForConditionalGeneration,
27
+ Qwen3ASRProcessor,
28
+ )
29
+ from transformers import AutoConfig, AutoModel, AutoProcessor
30
+
31
+ from .utils import (
32
+ AudioLike,
33
+ ensure_list,
34
+ normalize_audios,
35
+ )
36
+
37
+
38
+ class Qwen3ForceAlignProcessor():
39
+ def __init__(self):
40
+ ko_dict_path = os.path.join(os.path.dirname(__file__), "assets", "korean_dict_jieba.dict")
41
+ ko_scores = {}
42
+ with open(ko_dict_path, "r", encoding="utf-8") as f:
43
+ for line in f:
44
+ line = line.strip()
45
+ if not line:
46
+ continue
47
+ word = line.split()[0]
48
+ ko_scores[word] = 1.0
49
+ self.ko_score = ko_scores
50
+ self.ko_tokenizer = None
51
+
52
+ def is_kept_char(self, ch: str) -> bool:
53
+ if ch == "'":
54
+ return True
55
+ cat = unicodedata.category(ch)
56
+ if cat.startswith("L") or cat.startswith("N"):
57
+ return True
58
+ return False
59
+
60
+ def clean_token(self, token: str) -> str:
61
+ return "".join(ch for ch in token if self.is_kept_char(ch))
62
+
63
+ def is_cjk_char(self, ch: str) -> bool:
64
+ code = ord(ch)
65
+ return (
66
+ 0x4E00 <= code <= 0x9FFF # CJK Unified Ideographs
67
+ or 0x3400 <= code <= 0x4DBF # Extension A
68
+ or 0x20000 <= code <= 0x2A6DF # Extension B
69
+ or 0x2A700 <= code <= 0x2B73F # Extension C
70
+ or 0x2B740 <= code <= 0x2B81F # Extension D
71
+ or 0x2B820 <= code <= 0x2CEAF # Extension E
72
+ or 0xF900 <= code <= 0xFAFF # Compatibility Ideographs
73
+ )
74
+
75
+ def tokenize_chinese_mixed(self, text: str) -> List[str]:
76
+ tokens: List[str] = []
77
+ current_latin: List[str] = []
78
+
79
+ def flush_latin():
80
+ nonlocal current_latin
81
+ if current_latin:
82
+ token = "".join(current_latin)
83
+ cleaned = self.clean_token(token)
84
+ if cleaned:
85
+ tokens.append(cleaned)
86
+ current_latin = []
87
+
88
+ for ch in text:
89
+ if self.is_cjk_char(ch):
90
+ flush_latin()
91
+ tokens.append(ch)
92
+ else:
93
+ if self.is_kept_char(ch):
94
+ current_latin.append(ch)
95
+ else:
96
+ flush_latin()
97
+
98
+ flush_latin()
99
+
100
+ return tokens
101
+
102
+ def tokenize_japanese(self, text: str) -> List[str]:
103
+ words = nagisa.tagging(text).words
104
+ tokens: List[str] = []
105
+ for w in words:
106
+ cleaned = self.clean_token(w)
107
+ if cleaned:
108
+ tokens.append(cleaned)
109
+ return tokens
110
+
111
+ def tokenize_korean(self, ko_tokenizer, text: str) -> List[str]:
112
+ raw_tokens = ko_tokenizer.tokenize(text)
113
+ tokens: List[str] = []
114
+ for w in raw_tokens:
115
+ w_clean = self.clean_token(w)
116
+ if w_clean:
117
+ tokens.append(w_clean)
118
+ return tokens
119
+
120
+ def split_segment_with_chinese(self, seg: str) -> List[str]:
121
+ tokens: List[str] = []
122
+ buf: List[str] = []
123
+
124
+ def flush_buf():
125
+ nonlocal buf
126
+ if buf:
127
+ tokens.append("".join(buf))
128
+ buf = []
129
+
130
+ for ch in seg:
131
+ if self.is_cjk_char(ch):
132
+ flush_buf()
133
+ tokens.append(ch)
134
+ else:
135
+ buf.append(ch)
136
+
137
+ flush_buf()
138
+ return tokens
139
+
140
+ def tokenize_space_lang(self, text: str) -> List[str]:
141
+ tokens: List[str] = []
142
+ for seg in text.split():
143
+ cleaned = self.clean_token(seg)
144
+ if cleaned:
145
+ tokens.extend(self.split_segment_with_chinese(cleaned))
146
+ return tokens
147
+
148
+ def fix_timestamp(self, data: np.ndarray) -> List[int]:
149
+ data = data.tolist()
150
+ n = len(data)
151
+
152
+ dp = [1] * n
153
+ parent = [-1] * n
154
+
155
+ for i in range(1, n):
156
+ for j in range(i):
157
+ if data[j] <= data[i] and dp[j] + 1 > dp[i]:
158
+ dp[i] = dp[j] + 1
159
+ parent[i] = j
160
+
161
+ max_length = max(dp)
162
+ max_idx = dp.index(max_length)
163
+
164
+ lis_indices = []
165
+ idx = max_idx
166
+ while idx != -1:
167
+ lis_indices.append(idx)
168
+ idx = parent[idx]
169
+ lis_indices.reverse()
170
+
171
+ is_normal = [False] * n
172
+ for idx in lis_indices:
173
+ is_normal[idx] = True
174
+
175
+ result = data.copy()
176
+ i = 0
177
+
178
+ while i < n:
179
+ if not is_normal[i]:
180
+ j = i
181
+ while j < n and not is_normal[j]:
182
+ j += 1
183
+
184
+ anomaly_count = j - i
185
+
186
+ if anomaly_count <= 2:
187
+ left_val = None
188
+ for k in range(i - 1, -1, -1):
189
+ if is_normal[k]:
190
+ left_val = result[k]
191
+ break
192
+
193
+ right_val = None
194
+ for k in range(j, n):
195
+ if is_normal[k]:
196
+ right_val = result[k]
197
+ break
198
+
199
+ for k in range(i, j):
200
+ if left_val is None:
201
+ result[k] = right_val
202
+ elif right_val is None:
203
+ result[k] = left_val
204
+ else:
205
+ result[k] = left_val if (k - (i - 1)) <= ((j) - k) else right_val
206
+
207
+ else:
208
+ left_val = None
209
+ for k in range(i - 1, -1, -1):
210
+ if is_normal[k]:
211
+ left_val = result[k]
212
+ break
213
+
214
+ right_val = None
215
+ for k in range(j, n):
216
+ if is_normal[k]:
217
+ right_val = result[k]
218
+ break
219
+
220
+ if left_val is not None and right_val is not None:
221
+ step = (right_val - left_val) / (anomaly_count + 1)
222
+ for k in range(i, j):
223
+ result[k] = left_val + step * (k - i + 1)
224
+ elif left_val is not None:
225
+ for k in range(i, j):
226
+ result[k] = left_val
227
+ elif right_val is not None:
228
+ for k in range(i, j):
229
+ result[k] = right_val
230
+
231
+ i = j
232
+ else:
233
+ i += 1
234
+
235
+ return [int(res) for res in result]
236
+
237
+ def encode_timestamp(self, text: str, language: str) -> Tuple[List[str], str]:
238
+ language = language.lower()
239
+
240
+ if language.lower() == "japanese":
241
+ word_list = self.tokenize_japanese(text)
242
+ elif language.lower() == "korean":
243
+ if self.ko_tokenizer is None:
244
+ from soynlp.tokenizer import LTokenizer
245
+ self.ko_tokenizer = LTokenizer(scores=self.ko_score)
246
+ word_list = self.tokenize_korean(self.ko_tokenizer, text)
247
+ else:
248
+ word_list = self.tokenize_space_lang(text)
249
+
250
+ input_text = "<timestamp><timestamp>".join(word_list) + "<timestamp><timestamp>"
251
+ input_text = "<|audio_start|><|audio_pad|><|audio_end|>" + input_text
252
+
253
+ return word_list, input_text
254
+
255
+ def parse_timestamp(self, word_list: List[str], timestamp: np.ndarray) -> List[Dict[str, Any]]:
256
+ timestamp_output = []
257
+
258
+ timestamp_fixed = self.fix_timestamp(timestamp)
259
+ for i, word in enumerate(word_list):
260
+ start_time = timestamp_fixed[i * 2]
261
+ end_time = timestamp_fixed[i * 2 + 1]
262
+ timestamp_output.append({
263
+ "text": word,
264
+ "start_time": start_time,
265
+ "end_time": end_time
266
+ })
267
+
268
+ return timestamp_output
269
+
270
+
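
As a quick illustration (not part of the commit), `encode_timestamp` interleaves each tokenized unit with two `<timestamp>` placeholders after the audio placeholder tokens, and `parse_timestamp` later maps the model's predictions at those positions back onto the units. The expected values in the comments follow from the code above and are illustrative, not captured outputs.

```python
# Sketch of the aligner prompt built by Qwen3ForceAlignProcessor.encode_timestamp.
proc = Qwen3ForceAlignProcessor()

word_list, prompt = proc.encode_timestamp("Hello, world!", "English")
# tokenize_space_lang() keeps only letters/digits/apostrophes, so:
#   word_list == ["Hello", "world"]
#   prompt    == "<|audio_start|><|audio_pad|><|audio_end|>"
#                "Hello<timestamp><timestamp>world<timestamp><timestamp>"
# Each <timestamp> pair is later predicted as a time index; fix_timestamp() repairs
# non-monotonic predictions and parse_timestamp() turns them into per-word spans.
```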
271
+ @dataclass(frozen=True)
272
+ class ForcedAlignItem:
273
+ """
274
+ One aligned item span.
275
+
276
+ Attributes:
277
+ text (str):
278
+ The aligned unit (CJK character or word) produced by the forced aligner processor.
279
+ start_time (float):
280
+ Start time in seconds.
281
+ end_time (float):
282
+ End time in seconds.
283
+ """
284
+ text: str
285
+ start_time: float
286
+ end_time: float
287
+
288
+
289
+ @dataclass(frozen=True)
290
+ class ForcedAlignResult:
291
+ """
292
+ Forced alignment output for one sample.
293
+
294
+ Attributes:
295
+ items (List[ForcedAlignItem]):
296
+ Aligned token spans.
297
+ """
298
+ items: List[ForcedAlignItem]
299
+
300
+ def __iter__(self):
301
+ return iter(self.items)
302
+
303
+ def __len__(self):
304
+ return len(self.items)
305
+
306
+ def __getitem__(self, idx: int) -> ForcedAlignItem:
307
+ return self.items[idx]
308
+
309
+
310
+ class Qwen3ForcedAligner:
311
+ """
312
+ A HuggingFace-style wrapper for Qwen3-ForcedAligner model inference.
313
+
314
+ This wrapper provides:
315
+ - `from_pretrained()` initialization via HuggingFace AutoModel/AutoProcessor
316
+ - audio input normalization (path/URL/base64/(np.ndarray, sr))
317
+ - batch and single-sample forced alignment
318
+ - structured output with attribute access (`.text`, `.start_time`, `.end_time`)
319
+ """
320
+
321
+ def __init__(
322
+ self,
323
+ model: Qwen3ASRForConditionalGeneration,
324
+ processor: Qwen3ASRProcessor,
325
+ aligner_processor: Qwen3ForceAlignProcessor,
326
+ ):
327
+ self.model = model
328
+ self.processor = processor
329
+ self.aligner_processor = aligner_processor
330
+
331
+ self.device = getattr(model, "device", None)
332
+ if self.device is None:
333
+ try:
334
+ self.device = next(model.parameters()).device
335
+ except StopIteration:
336
+ self.device = torch.device("cpu")
337
+
338
+ self.timestamp_token_id = int(model.config.timestamp_token_id)
339
+ self.timestamp_segment_time = float(model.config.timestamp_segment_time)
340
+
341
+ @classmethod
342
+ def from_pretrained(
343
+ cls,
344
+ pretrained_model_name_or_path: str,
345
+ **kwargs,
346
+ ) -> "Qwen3ForcedAligner":
347
+ """
348
+ Load Qwen3-ForcedAligner model and initialize processors.
349
+
350
+ This method:
351
+ 1) Registers config/model/processor for HF auto classes.
352
+ 2) Loads the model using `AutoModel.from_pretrained(...)`.
353
+ 3) Initializes:
354
+ - HF processor (`AutoProcessor.from_pretrained(...)`)
355
+ - forced alignment text processor (`Qwen3ForceAlignProcessor()`)
356
+
357
+ Args:
358
+ pretrained_model_name_or_path (str):
359
+ HuggingFace repo id or local directory.
360
+ **kwargs:
361
+ Forwarded to `AutoModel.from_pretrained(...)`.
362
+ Typical examples: device_map="cuda:0", dtype=torch.bfloat16.
363
+
364
+ Returns:
365
+ Qwen3ForcedAligner:
366
+ Initialized wrapper instance.
367
+ """
368
+ AutoConfig.register("qwen3_asr", Qwen3ASRConfig)
369
+ AutoModel.register(Qwen3ASRConfig, Qwen3ASRForConditionalGeneration)
370
+ AutoProcessor.register(Qwen3ASRConfig, Qwen3ASRProcessor)
371
+
372
+ model = AutoModel.from_pretrained(pretrained_model_name_or_path, **kwargs)
373
+ if not isinstance(model, Qwen3ASRForConditionalGeneration):
374
+ raise TypeError(
375
+ f"AutoModel returned {type(model)}, expected Qwen3ASRForConditionalGeneration."
376
+ )
377
+
378
+ processor = AutoProcessor.from_pretrained(pretrained_model_name_or_path, fix_mistral_regex=True)
379
+ aligner_processor = Qwen3ForceAlignProcessor()
380
+
381
+ return cls(model=model, processor=processor, aligner_processor=aligner_processor)
382
+
383
+ def _to_structured_items(self, timestamp_output: List[Dict[str, Any]]) -> ForcedAlignResult:
384
+ items: List[ForcedAlignItem] = []
385
+ for it in timestamp_output:
386
+ items.append(
387
+ ForcedAlignItem(
388
+ text=str(it.get("text", "")),
389
+ start_time=float(it.get("start_time", 0)),
390
+ end_time=float(it.get("end_time", 0)),
391
+ )
392
+ )
393
+ return ForcedAlignResult(items=items)
394
+
395
+ @torch.inference_mode()
396
+ def align(
397
+ self,
398
+ audio: Union[AudioLike, List[AudioLike]],
399
+ text: Union[str, List[str]],
400
+ language: Union[str, List[str]],
401
+ ) -> List[ForcedAlignResult]:
402
+ """
403
+ Run forced alignment for batch or single sample.
404
+
405
+ Args:
406
+ audio:
407
+ Audio input(s). Each item supports:
408
+ - local path / https URL / base64 string
409
+ - (np.ndarray, sr)
410
+ All audios will be converted into mono 16k float32 arrays in [-1, 1].
411
+ text:
412
+ Transcript(s) for alignment.
413
+ language:
414
+ Language(s) for each sample (e.g., "Chinese", "English").
415
+
416
+ Returns:
417
+ List[ForcedAlignResult]:
418
+ One result per sample. Each result contains `items`, and each token can be accessed via
419
+ `.text`, `.start_time`, `.end_time`.
420
+ """
421
+ texts = ensure_list(text)
422
+ languages = ensure_list(language)
423
+ audios = normalize_audios(audio)
424
+
425
+ if len(languages) == 1 and len(audios) > 1:
426
+ languages = languages * len(audios)
427
+
428
+ if not (len(audios) == len(texts) == len(languages)):
429
+ raise ValueError(
430
+ f"Batch size mismatch: audio={len(audios)}, text={len(texts)}, language={len(languages)}"
431
+ )
432
+
433
+ word_lists = []
434
+ aligner_input_texts = []
435
+ for t, lang in zip(texts, languages):
436
+ word_list, aligner_input_text = self.aligner_processor.encode_timestamp(t, lang)
437
+ word_lists.append(word_list)
438
+ aligner_input_texts.append(aligner_input_text)
439
+
440
+ inputs = self.processor(
441
+ text=aligner_input_texts,
442
+ audio=audios,
443
+ return_tensors="pt",
444
+ padding=True,
445
+ )
446
+ inputs = inputs.to(self.model.device).to(self.model.dtype)
447
+
448
+ logits = self.model.thinker(**inputs).logits
449
+ output_ids = logits.argmax(dim=-1)
450
+
451
+ results: List[ForcedAlignResult] = []
452
+ for input_id, output_id, word_list in zip(inputs["input_ids"], output_ids, word_lists):
453
+ masked_output_id = output_id[input_id == self.timestamp_token_id]
454
+ timestamp_ms = (masked_output_id * self.timestamp_segment_time).to("cpu").numpy()
455
+ timestamp_output = self.aligner_processor.parse_timestamp(word_list, timestamp_ms)
456
+ for it in timestamp_output:
457
+ it['start_time'] = round(it['start_time'] / 1000.0, 3)
458
+ it['end_time'] = round(it['end_time'] / 1000.0, 3)
459
+ results.append(self._to_structured_items(timestamp_output))
460
+
461
+ return results
462
+
463
+ def get_supported_languages(self) -> Optional[List[str]]:
464
+ """
465
+ List supported language names for the current model.
466
+
467
+ This is a thin wrapper around `self.model.get_support_languages()`.
468
+ If the underlying model does not expose language constraints (returns None),
469
+ this method also returns None.
470
+
471
+ Returns:
472
+ Optional[List[str]]:
473
+ - A sorted list of supported language names (lowercased), if available.
474
+ - None if the model does not provide supported languages.
475
+ """
476
+ fn = getattr(self.model, "get_support_languages", None)
477
+ if not callable(fn):
478
+ return None
479
+
480
+ langs = fn()
481
+ if langs is None:
482
+ return None
483
+
484
+ return sorted({str(x).lower() for x in langs})
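
End to end, the wrapper is meant to be used roughly as below. This is a hedged sketch rather than an official example: the checkpoint id and audio path are placeholders, and the `device_map`/`dtype` kwargs are just the "typical examples" the docstring mentions for arguments forwarded to `AutoModel.from_pretrained`.

```python
# Hedged usage sketch (placeholder checkpoint id and audio path; assumes a CUDA device).
import torch
from qwen_asr.inference.qwen3_forced_aligner import Qwen3ForcedAligner

aligner = Qwen3ForcedAligner.from_pretrained(
    "Qwen/Qwen3-ForcedAligner-0.6B",   # the aligner checkpoint this Space uses
    device_map="cuda:0",
    dtype=torch.bfloat16,
)

results = aligner.align(
    audio="sample.wav",                # path / URL / base64 string / (np.ndarray, sr)
    text="hello world",
    language="English",
)

for item in results[0]:                # ForcedAlignResult is iterable
    print(item.text, item.start_time, item.end_time)
```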
qwen_asr/inference/utils.py ADDED
@@ -0,0 +1,497 @@
1
+ # coding=utf-8
2
+ # Copyright 2026 The Alibaba Qwen team.
3
+ # SPDX-License-Identifier: Apache-2.0
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ import base64
17
+ import io
18
+ import urllib.request
19
+ from dataclasses import dataclass
20
+ from typing import Any, Iterable, List, Optional, Tuple, Union
21
+ from urllib.parse import urlparse
22
+
23
+ import librosa
24
+ import numpy as np
25
+ import soundfile as sf
26
+
27
+ AudioLike = Union[
28
+ str, # wav path / URL / base64
29
+ Tuple[np.ndarray, int], # (waveform, sr)
30
+ ]
31
+ MaybeList = Union[Any, List[Any]]
32
+
33
+ SAMPLE_RATE = 16000
34
+ MAX_ASR_INPUT_SECONDS = 1200
35
+ MAX_FORCE_ALIGN_INPUT_SECONDS = 180
36
+ MIN_ASR_INPUT_SECONDS = 0.5
37
+ SUPPORTED_LANGUAGES: List[str] = [
38
+ "Chinese",
39
+ "English",
40
+ "Cantonese",
41
+ "Arabic",
42
+ "German",
43
+ "French",
44
+ "Spanish",
45
+ "Portuguese",
46
+ "Indonesian",
47
+ "Italian",
48
+ "Korean",
49
+ "Russian",
50
+ "Thai",
51
+ "Vietnamese",
52
+ "Japanese",
53
+ "Turkish",
54
+ "Hindi",
55
+ "Malay",
56
+ "Dutch",
57
+ "Swedish",
58
+ "Danish",
59
+ "Finnish",
60
+ "Polish",
61
+ "Czech",
62
+ "Filipino",
63
+ "Persian",
64
+ "Greek",
65
+ "Romanian",
66
+ "Hungarian",
67
+ "Macedonian"
68
+ ]
69
+ _ASR_TEXT_TAG = "<asr_text>"
70
+ _LANG_PREFIX = "language "
71
+
72
+
73
+ def normalize_language_name(language: str) -> str:
74
+ """
75
+ Normalize language name to the canonical format used by Qwen3-ASR:
76
+ first letter uppercase, the rest lowercase (e.g., 'cHINese' -> 'Chinese').
77
+
78
+ Args:
79
+ language (str): Input language name.
80
+
81
+ Returns:
82
+ str: Normalized language name.
83
+
84
+ Raises:
85
+ ValueError: If language is empty.
86
+ """
87
+ if language is None:
88
+ raise ValueError("language is None")
89
+ s = str(language).strip()
90
+ if not s:
91
+ raise ValueError("language is empty")
92
+ return s[:1].upper() + s[1:].lower()
93
+
94
+
95
+ def validate_language(language: str) -> None:
96
+ """
97
+ Validate the language is supported.
98
+
99
+ Args:
100
+ language (str): Canonical language name.
101
+
102
+ Raises:
103
+ ValueError: If unsupported.
104
+ """
105
+ if language not in SUPPORTED_LANGUAGES:
106
+ raise ValueError(f"Unsupported language: {language}. Supported: {SUPPORTED_LANGUAGES}")
107
+
108
+
109
+ def ensure_list(x: MaybeList) -> List[Any]:
110
+ return x if isinstance(x, list) else [x]
111
+
112
+
113
+ def is_url(s: str) -> bool:
114
+ try:
115
+ u = urlparse(s)
116
+ return u.scheme in ("http", "https") and bool(u.netloc)
117
+ except Exception:
118
+ return False
119
+
120
+
121
+ def is_probably_base64(s: str) -> bool:
122
+ if s.startswith("data:audio"):
123
+ return True
124
+ if ("/" not in s and "\\" not in s) and len(s) > 256:
125
+ return True
126
+ return False
127
+
128
+
129
+ def decode_base64_bytes(b64: str) -> bytes:
130
+ if "," in b64 and b64.strip().startswith("data:"):
131
+ b64 = b64.split(",", 1)[1]
132
+ return base64.b64decode(b64)
133
+
134
+
135
+ def load_audio_any(x: str) -> Tuple[np.ndarray, int]:
136
+ if is_url(x):
137
+ with urllib.request.urlopen(x) as resp:
138
+ audio_bytes = resp.read()
139
+ with io.BytesIO(audio_bytes) as f:
140
+ audio, sr = sf.read(f, dtype="float32", always_2d=False)
141
+ elif is_probably_base64(x):
142
+ audio_bytes = decode_base64_bytes(x)
143
+ with io.BytesIO(audio_bytes) as f:
144
+ audio, sr = sf.read(f, dtype="float32", always_2d=False)
145
+ else:
146
+ audio, sr = librosa.load(x, sr=None, mono=False)
147
+
148
+ audio = np.asarray(audio, dtype=np.float32)
149
+ sr = int(sr)
150
+ return audio, sr
151
+
152
+
153
+ def to_mono(audio: np.ndarray) -> np.ndarray:
154
+ if audio.ndim == 1:
155
+ return audio
156
+ # soundfile can return shape (T, C); some pipelines use (C, T)
157
+ if audio.ndim == 2:
158
+ if audio.shape[0] <= 8 and audio.shape[1] > audio.shape[0]:
159
+ audio = audio.T
160
+ return np.mean(audio, axis=-1).astype(np.float32)
161
+ raise ValueError(f"Unsupported audio ndim={audio.ndim}")
162
+
163
+
164
+ def float_range_normalize(audio: np.ndarray) -> np.ndarray:
165
+ audio = audio.astype(np.float32)
166
+ if audio.size == 0:
167
+ return audio
168
+ peak = float(np.max(np.abs(audio)))
169
+ if peak == 0.0:
170
+ return audio
171
+ # If decoded audio is int-like scaled or out-of-range, normalize conservatively.
172
+ if peak > 1.0:
173
+ audio = audio / peak
174
+ audio = np.clip(audio, -1.0, 1.0)
175
+ return audio
176
+
177
+
178
+ def normalize_audio_input(a: AudioLike) -> np.ndarray:
179
+ """
180
+ Normalize one audio input to mono 16k float32 waveform in [-1, 1].
181
+
182
+ Supported inputs:
183
+ - str: local file path / https URL / base64 audio string
184
+ - (np.ndarray, sr): waveform and sampling rate
185
+
186
+ Returns:
187
+ np.ndarray:
188
+ Mono 16k float32 waveform in [-1, 1].
189
+ """
190
+ if isinstance(a, str):
191
+ audio, sr = load_audio_any(a)
192
+ elif isinstance(a, tuple) and len(a) == 2 and isinstance(a[0], np.ndarray):
193
+ audio, sr = a[0], int(a[1])
194
+ else:
195
+ raise TypeError(f"Unsupported audio input type: {type(a)}")
196
+
197
+ audio = to_mono(np.asarray(audio))
198
+ if sr != SAMPLE_RATE:
199
+ audio = librosa.resample(audio, orig_sr=sr, target_sr=SAMPLE_RATE).astype(np.float32)
200
+ audio = float_range_normalize(audio)
201
+ return audio
202
+
203
+
204
+ def normalize_audios(audios: Union[AudioLike, List[AudioLike]]) -> List[np.ndarray]:
205
+ items = ensure_list(audios)
206
+ return [normalize_audio_input(a) for a in items]
207
+
208
+
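
Illustrative only (not part of the commit): whatever form the input takes, the output is always a list of mono 16 kHz float32 arrays in [-1, 1]. The import path assumes the package layout of this commit, and the arrays are synthetic.

```python
# Sketch: normalize_audios() canonicalizes (waveform, sr) tuples (and paths/URLs/base64 strings).
import numpy as np
from qwen_asr.inference.utils import normalize_audios

wavs = normalize_audios([
    (np.zeros(8000, dtype=np.float32), 8000),                          # 1 s of silence at 8 kHz
    (np.random.uniform(-2.0, 2.0, 48000).astype(np.float32), 48000),   # loud 1 s clip at 48 kHz
])

assert all(w.dtype == np.float32 and w.ndim == 1 for w in wavs)   # mono float32
assert all(np.max(np.abs(w)) <= 1.0 for w in wavs)                # peak-normalized into [-1, 1]
```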
209
+ def chunk_list(xs: List[Any], chunk_size: int) -> Iterable[List[Any]]:
210
+ """
211
+ Yield chunks of a list.
212
+
213
+ Args:
214
+ xs (List[Any]): Input list.
215
+ chunk_size (int): Chunk size.
216
+
217
+ Yields:
218
+ List[Any]: Slices of xs.
219
+ """
220
+ if chunk_size is None or chunk_size <= 0:
221
+ yield xs
222
+ return
223
+ for i in range(0, len(xs), chunk_size):
224
+ yield xs[i : i + chunk_size]
225
+
226
+
227
+ @dataclass(frozen=True)
228
+ class AudioChunk:
229
+ """
230
+ One chunk cut from an original audio.
231
+
232
+ Attributes:
233
+ orig_index: Index of the original sample in the input batch.
234
+ chunk_index: Index of this chunk within the original sample.
235
+ wav: Mono float32 waveform.
236
+ sr: Sampling rate.
237
+ offset_sec: Start offset of this chunk in the original audio, in seconds.
238
+ """
239
+ orig_index: int
240
+ chunk_index: int
241
+ wav: np.ndarray
242
+ sr: int
243
+ offset_sec: float
244
+
245
+
246
+ def split_audio_into_chunks(
247
+ wav: np.ndarray,
248
+ sr: int,
249
+ max_chunk_sec: float,
250
+ search_expand_sec: float = 5.0,
251
+ min_window_ms: float = 100.0,
252
+ ) -> List[Tuple[np.ndarray, float]]:
253
+ """
254
+ Split a long audio into chunks close to max_chunk_sec, using a low-energy boundary.
255
+
256
+ This implementation guarantees:
257
+ - Concatenating all returned chunks reproduces the original audio exactly
258
+ (total number of samples is identical, no overlaps, no gaps), apart from zero-padding appended to any chunk shorter than MIN_ASR_INPUT_SECONDS.
259
+
260
+ Args:
261
+ wav: Mono waveform float32.
262
+ sr: Sampling rate.
263
+ max_chunk_sec: Target max chunk duration in seconds.
264
+ search_expand_sec: Boundary search half-window in seconds.
265
+ min_window_ms: Sliding window in milliseconds for energy estimation.
266
+
267
+ Returns:
268
+ List[Tuple[np.ndarray, float]]: List of (chunk_wav, offset_sec).
269
+ """
270
+ wav = np.asarray(wav, dtype=np.float32)
271
+ if wav.ndim > 1:
272
+ wav = np.mean(wav, axis=-1).astype(np.float32)
273
+
274
+ total_len = int(wav.shape[0])
275
+ total_sec = total_len / float(sr)
276
+ if total_sec <= max_chunk_sec:
277
+ return [(wav, 0.0)]
278
+
279
+ max_len = int(max_chunk_sec * sr)
280
+ expand = int(search_expand_sec * sr)
281
+ win = max(4, int((min_window_ms / 1000.0) * sr))
282
+
283
+ chunks: List[Tuple[np.ndarray, float]] = []
284
+
285
+ start = 0
286
+ offset_sec = 0.0
287
+
288
+ while (total_len - start) > max_len:
289
+ cut = start + max_len
290
+
291
+ left = max(start, cut - expand)
292
+ right = min(total_len, cut + expand)
293
+
294
+ if right - left <= win:
295
+ boundary = cut
296
+ else:
297
+ seg = wav[left:right]
298
+ seg_abs = np.abs(seg)
299
+
300
+ window_sums = np.convolve(seg_abs, np.ones(win, dtype=np.float32), mode="valid")
301
+
302
+ min_pos = int(np.argmin(window_sums))
303
+
304
+ wstart = min_pos
305
+ wend = min_pos + win
306
+ local = seg_abs[wstart:wend]
307
+ inner = int(np.argmin(local))
308
+ boundary = left + wstart + inner
309
+
310
+ boundary = int(max(boundary, start + 1))
311
+ boundary = int(min(boundary, total_len))
312
+
313
+ chunk = wav[start:boundary]
314
+ chunks.append((chunk, offset_sec))
315
+
316
+ offset_sec += (boundary - start) / float(sr)
317
+ start = boundary
318
+
319
+ tail = wav[start:total_len]
320
+ chunks.append((tail, offset_sec))
321
+
322
+ # Pad too-short chunks to at least MIN_ASR_INPUT_SECONDS (zero-padding at tail)
323
+ min_len = int(MIN_ASR_INPUT_SECONDS * sr)
324
+ padded: List[Tuple[np.ndarray, float]] = []
325
+ for c, off in chunks:
326
+ if c.shape[0] < min_len:
327
+ pad = min_len - int(c.shape[0])
328
+ c = np.pad(c, (0, pad), mode="constant", constant_values=0.0).astype(np.float32)
329
+ padded.append((c, off))
330
+ chunks = padded
331
+
332
+ return chunks
333
+
334
+
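
A small self-check (illustrative, not part of the commit) of the reconstruction property documented above; the waveform is synthetic noise and the 300 s chunk size is an arbitrary choice, large enough that only the final chunk could ever be short enough to receive tail padding.

```python
# Sketch: concatenating the chunks reproduces the original samples.
import numpy as np
from qwen_asr.inference.utils import split_audio_into_chunks

sr = 16000
wav = np.random.uniform(-1.0, 1.0, size=sr * 700).astype(np.float32)   # ~700 s of noise

chunks = split_audio_into_chunks(wav, sr, max_chunk_sec=300.0)
rebuilt = np.concatenate([chunk for chunk, _offset in chunks])

assert rebuilt.shape[0] >= wav.shape[0]                # can only grow via tail zero-padding
assert np.array_equal(rebuilt[: wav.shape[0]], wav)    # original samples are preserved in order
```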
335
+ def detect_and_fix_repetitions(text: str, threshold: int = 20) -> str:
+ """Collapse degenerate repetitions (a character or short pattern repeated roughly `threshold`+ times in a row) to a single occurrence."""
336
+ def fix_char_repeats(s, thresh):
337
+ res = []
338
+ i = 0
339
+ n = len(s)
340
+ while i < n:
341
+ count = 1
342
+ while i + count < n and s[i + count] == s[i]:
343
+ count += 1
344
+
345
+ if count > thresh:
346
+ res.append(s[i])
347
+ i += count
348
+ else:
349
+ res.append(s[i:i+count])
350
+ i += count
351
+ return ''.join(res)
352
+
353
+ def fix_pattern_repeats(s, thresh, max_len=20):
354
+ n = len(s)
355
+ min_repeat_chars = thresh * 2
356
+ if n < min_repeat_chars:
357
+ return s
358
+
359
+ i = 0
360
+ result = []
361
+ while i <= n - min_repeat_chars:
362
+ found = False
363
+ for k in range(1, max_len + 1):
364
+ if i + k * thresh > n:
365
+ break
366
+
367
+ pattern = s[i:i+k]
368
+ valid = True
369
+ for rep in range(1, thresh):
370
+ start_idx = i + rep * k
371
+ if s[start_idx:start_idx+k] != pattern:
372
+ valid = False
373
+ break
374
+
375
+ if valid:
376
+ total_rep = thresh
377
+ end_index = i + thresh * k
378
+ while end_index + k <= n and s[end_index:end_index+k] == pattern:
379
+ total_rep += 1
380
+ end_index += k
381
+ result.append(pattern)
382
+ result.append(fix_pattern_repeats(s[end_index:], thresh, max_len))
383
+ i = n
384
+ found = True
385
+ break
386
+
387
+ if found:
388
+ break
389
+ else:
390
+ result.append(s[i])
391
+ i += 1
392
+
393
+ if not found:
394
+ result.append(s[i:])
395
+ return ''.join(result)
396
+
397
+ text_raw = text
398
+ text = fix_char_repeats(text_raw, threshold)
399
+ text = fix_pattern_repeats(text, threshold)
400
+ return text
401
+
402
+
403
+ def parse_asr_output(
404
+ raw: str,
405
+ user_language: Optional[str] = None,
406
+ ) -> Tuple[str, str]:
407
+ """
408
+ Parse Qwen3-ASR raw output into (language, text).
409
+
410
+ Cases:
411
+ - With tag: "language Chinese<asr_text>...."
412
+ - With newlines: "language Chinese\\n...\\n<asr_text>...."
413
+ - No tag: treat whole string as text.
414
+ - "language None<asr_text>": treat as empty audio -> ("", "")
415
+
416
+ If user_language is provided, language is forced to user_language and raw is treated as text-only
417
+ (the model is expected to output plain transcription without metadata).
418
+
419
+ Args:
420
+ raw: Raw decoded string.
421
+ user_language: Canonical language name if user forced language.
422
+
423
+ Returns:
424
+ Tuple[str, str]: (language, text)
425
+ """
426
+ if raw is None:
427
+ return "", ""
428
+ s = str(raw).strip()
429
+ if not s:
430
+ return "", ""
431
+
432
+ s = detect_and_fix_repetitions(s)
433
+
434
+ if user_language:
435
+ # user explicitly forced language => model output is treated as pure text
436
+ return user_language, s
437
+
438
+ meta_part = s
439
+ text_part = ""
440
+ has_tag = _ASR_TEXT_TAG in s
441
+ if has_tag:
442
+ meta_part, text_part = s.split(_ASR_TEXT_TAG, 1)
443
+ else:
444
+ # no tag => pure text
445
+ return "", s.strip()
446
+
447
+ meta_lower = meta_part.lower()
448
+
449
+ # empty audio heuristic
450
+ if "language none" in meta_lower:
451
+ t = text_part.strip()
452
+ if not t:
453
+ return "", ""
454
+ # if model still returned something, keep it but language unknown
455
+ return "", t
456
+
457
+ # extract "language xxx" from meta
458
+ lang = ""
459
+ for line in meta_part.splitlines():
460
+ line = line.strip()
461
+ if not line:
462
+ continue
463
+ low = line.lower()
464
+ if low.startswith(_LANG_PREFIX):
465
+ val = line[len(_LANG_PREFIX):].strip()
466
+ if val:
467
+ lang = normalize_language_name(val)
468
+ break
469
+
470
+ return lang, text_part.strip()
471
+
472
+
473
+ def merge_languages(langs: List[str]) -> str:
474
+ """
475
+ Merge per-chunk languages into a compact comma-separated string,
476
+ keeping order and removing consecutive duplicates and empty entries.
477
+
478
+ Example:
479
+ ["Chinese", "English", "English"] -> "Chinese,English"
480
+
481
+ Args:
482
+ langs: List of canonical language names.
483
+
484
+ Returns:
485
+ str: Merged language string.
486
+ """
487
+ out: List[str] = []
488
+ prev = None
489
+ for x in langs:
490
+ x = (x or "").strip()
491
+ if not x:
492
+ continue
493
+ if x == prev:
494
+ continue
495
+ out.append(x)
496
+ prev = x
497
+ return ",".join(out)
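
To make the parsing contract concrete, a short sketch (not part of the commit); the raw strings are constructed from the docstring cases above rather than captured model output, and the expected tuples in the comments follow from the code.

```python
# Sketch of parse_asr_output / merge_languages behaviour on docstring-style inputs.
from qwen_asr.inference.utils import merge_languages, parse_asr_output

print(parse_asr_output("language Chinese<asr_text>你好世界"))
# expected: ("Chinese", "你好世界")

print(parse_asr_output("plain transcription with no metadata tag"))
# expected: ("", "plain transcription with no metadata tag")

print(parse_asr_output("language None<asr_text>"))
# expected: ("", "")  -- treated as empty audio

print(merge_languages(["Chinese", "English", "English", "", "Chinese"]))
# expected: "Chinese,English,Chinese"
```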
requirements.txt CHANGED
@@ -12,4 +12,3 @@ sox
12
  scipy
13
  gradio>=4.0.0
14
  spaces
15
- qwen-asr
 