HaoranWei committed on
Commit db8cee3 · verified · 1 Parent(s): b5ac633

initial commit

.ipynb_checkpoints/README-checkpoint.md ADDED
@@ -0,0 +1,122 @@
1
+ ---
2
+ pipeline_tag: image-text-to-text
3
+ language:
4
+ - multilingual
5
+ tags:
6
+ - deepseek
7
+ - vision-language
8
+ - ocr
9
+ - custom_code
10
+ license: mit
11
+ library_name: transformers
12
+ ---
13
+ <div align="center">
14
+ <img src="https://github.com/deepseek-ai/DeepSeek-V2/blob/main/figures/logo.svg?raw=true" width="60%" alt="DeepSeek AI" />
15
+ </div>
16
+ <hr>
17
+ <div align="center">
18
+ <a href="https://www.deepseek.com/" target="_blank">
19
+ <img alt="Homepage" src="https://github.com/deepseek-ai/DeepSeek-V2/blob/main/figures/badge.svg?raw=true" />
20
+ </a>
21
+ <a href="https://huggingface.co/deepseek-ai/DeepSeek-OCR-2" target="_blank">
22
+ <img alt="Hugging Face" src="https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-DeepSeek%20AI-ffc107?color=ffc107&logoColor=white" />
23
+ </a>
24
+
25
+ </div>
26
+
27
+ <div align="center">
28
+
29
+ <a href="https://discord.gg/Tc7c45Zzu5" target="_blank">
30
+ <img alt="Discord" src="https://img.shields.io/badge/Discord-DeepSeek%20AI-7289da?logo=discord&logoColor=white&color=7289da" />
31
+ </a>
32
+ <a href="https://twitter.com/deepseek_ai" target="_blank">
33
+ <img alt="Twitter Follow" src="https://img.shields.io/badge/Twitter-deepseek_ai-white?logo=x&logoColor=white" />
34
+ </a>
35
+
36
+ </div>
37
+
38
+
39
+
40
+ <p align="center">
41
+ <a href="https://github.com/deepseek-ai/DeepSeek-OCR-2"><b>🌟 Github</b></a> |
42
+ <a href="https://huggingface.co/deepseek-ai/DeepSeek-OCR-2"><b>📥 Model Download</b></a> |
43
+ <a href="https://github.com/deepseek-ai/DeepSeek-OCR-2/blob/main/DeepSeek_OCR2_paper.pdf"><b>📄 Paper Link</b></a> |
44
+ <a href="https://github.com/deepseek-ai/DeepSeek-OCR-2/blob/main/DeepSeek_OCR2_paper.pdf"><b>📄 Arxiv Paper Link</b></a> |
45
+ </p>
46
+ <h2>
47
+ <p align="center">
48
+ <a href="">DeepSeek-OCR 2: Visual Causal Flow</a>
49
+ </p>
50
+ </h2>
51
+ <p align="center">
52
+ <img src="assets/fig1.png" style="width: 900px" align=center>
53
+ </p>
54
+ <p align="center">
55
+ <a href="">Explore more human-like visual encoding.</a>
56
+ </p>
57
+
58
+ ## Usage
59
+
60
+ Inference uses Hugging Face Transformers on NVIDIA GPUs. The requirements below were tested with Python 3.12.9 and CUDA 11.8:
61
+
62
+ ```
63
+ torch==2.6.0
64
+ transformers==4.46.3
65
+ tokenizers==0.20.3
66
+ einops
67
+ addict
68
+ easydict
69
+ # flash-attn is installed separately:
+ pip install flash-attn==2.7.3 --no-build-isolation
70
+ ```
71
+
72
+ ```python
73
+ from transformers import AutoModel, AutoTokenizer
74
+ import torch
75
+ import os
76
+ os.environ["CUDA_VISIBLE_DEVICES"] = '0'
77
+ model_name = 'deepseek-ai/DeepSeek-OCR-2'
78
+
79
+ tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
80
+ model = AutoModel.from_pretrained(model_name, _attn_implementation='flash_attention_2', trust_remote_code=True, use_safetensors=True)
81
+ model = model.eval().cuda().to(torch.bfloat16)
82
+
83
+ # prompt = "<image>\nFree OCR. "
84
+ prompt = "<image>\n<|grounding|>Convert the document to markdown. "
85
+ image_file = 'your_image.jpg'
86
+ output_path = 'your/output/dir'
87
+
88
+
89
+ res = model.infer(tokenizer, prompt=prompt, image_file=image_file, output_path=output_path, base_size=1024, image_size=768, crop_mode=True, save_results=True)
90
+ ```
91
+
92
+ ## vLLM
93
+
94
+
95
+ Refer to [🌟 GitHub](https://github.com/deepseek-ai/DeepSeek-OCR-2/) for guidance on inference acceleration, PDF processing, and more.
96
+
97
+ ## Supported modes
98
+ - Dynamic resolution
99
+   - Default: (0-6)×768×768 crop views + 1×1024×1024 global view, i.e. (0-6)×144 + 256 visual tokens ✅ (see the token-count sketch below)
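+
+ A quick way to sanity-check the visual-token budget of this default mode (a minimal sketch; the per-view counts of 144 and 256 tokens are taken from the line above):
+
+ ```python
+ def visual_token_budget(num_crops: int) -> int:
+     """Visual tokens used by the default dynamic-resolution mode."""
+     assert 0 <= num_crops <= 6
+     crop_tokens = 144 * num_crops   # each 768x768 crop view contributes 144 tokens
+     global_tokens = 256             # the single 1024x1024 global view contributes 256 tokens
+     return crop_tokens + global_tokens
+
+ print([visual_token_budget(n) for n in range(7)])
+ # [256, 400, 544, 688, 832, 976, 1120]
+ ```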
100
+
101
+ ## Prompt examples
102
+ ```python
103
+ # document: <image>\n<|grounding|>Convert the document to markdown.
104
+ # other image: <image>\n<|grounding|>OCR this image.
105
+ # without layouts: <image>\nFree OCR.
106
+ # figures in document: <image>\nParse the figure.
107
+ # general: <image>\nDescribe this image in detail.
108
+ # rec: <image>\nLocate <|ref|>xxxx<|/ref|> in the image.
109
+ ```
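+
+ For example, any of these prompts can be dropped into the `model.infer` call from the Usage section above (a minimal sketch that reuses `tokenizer`, `model`, and the placeholder paths defined there):
+
+ ```python
+ # layout-free OCR with the same call signature as in Usage
+ prompt = "<image>\nFree OCR. "
+ res = model.infer(tokenizer, prompt=prompt, image_file='your_image.jpg', output_path='your/output/dir', base_size=1024, image_size=768, crop_mode=True, save_results=True)
+ ```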
110
+
111
+
112
+ ## Acknowledgement
113
+
114
+ We would like to thank [DeepSeek-OCR](https://github.com/deepseek-ai/DeepSeek-OCR/), [Vary](https://github.com/Ucas-HaoranWei/Vary/), [GOT-OCR2.0](https://github.com/Ucas-HaoranWei/GOT-OCR2.0/), [MinerU](https://github.com/opendatalab/MinerU), [PaddleOCR](https://github.com/PaddlePaddle/PaddleOCR) for their valuable models and ideas.
115
+
116
+ We also appreciate the benchmark [OmniDocBench](https://github.com/opendatalab/OmniDocBench).
117
+
118
+
119
+ ## Citation
120
+
121
+ ```bibtex
122
+ coming soon~
.ipynb_checkpoints/deepencoderv2-checkpoint.py ADDED
@@ -0,0 +1,1015 @@
1
+ import torch.nn as nn
2
+ import torch
3
+ import torch.nn.functional as F
4
+ import copy
5
+
6
+
7
+ from typing import Optional, Tuple
8
+
9
+ # from megatron.model import LayerNorm
10
+
11
+ import transformers
12
+
13
+
14
+ from typing import Optional, Tuple, Type
15
+ from functools import partial
16
+
17
+
18
+
19
+ class MlpProjector(nn.Module):
20
+
21
+ def __init__(self, cfg):
22
+
23
+ super().__init__()
24
+
25
+ self.cfg = cfg
26
+
27
+ if cfg.projector_type == "identity":
28
+ modules = nn.Identity()
29
+
30
+ elif cfg.projector_type == "linear":
31
+ modules = nn.Linear(cfg.input_dim, cfg.n_embed)
32
+
33
+ elif cfg.projector_type == "mlp_gelu":
34
+ mlp_depth = cfg.get("depth", 1)
35
+ modules = [nn.Linear(cfg.input_dim, cfg.n_embed)]
36
+ for _ in range(1, mlp_depth):
37
+ modules.append(nn.GELU())
38
+ modules.append(nn.Linear(cfg.n_embed, cfg.n_embed))
39
+ modules = nn.Sequential(*modules)
40
+
41
+ elif cfg.projector_type == "normlayer_downsample_mlp_gelu":
42
+ mlp_depth = cfg.get("depth", 1)
43
+ mlp_ratio = cfg.get("mlp_ratio", 1)
44
+ modules = [
45
+ nn.LayerNorm(cfg.input_dim * cfg.downsample_ratio * cfg.downsample_ratio),
46
+ nn.Linear(cfg.input_dim * cfg.downsample_ratio * cfg.downsample_ratio, cfg.n_embed * mlp_ratio)
47
+ ]
48
+ for _ in range(1, mlp_depth - 1):
49
+ modules.append(nn.GELU())
50
+ modules.append(nn.Linear(cfg.n_embed * mlp_ratio, cfg.n_embed * mlp_ratio))
51
+ modules.append(nn.GELU())
52
+ modules.append(nn.Linear(cfg.n_embed * mlp_ratio, cfg.n_embed))
53
+ modules = nn.Sequential(*modules)
54
+
55
+ elif cfg.projector_type == "downsample_mlp_gelu":
56
+ mlp_depth = cfg.get("depth", 1)
57
+ mlp_ratio = cfg.get("mlp_ratio", 1)
58
+ modules = [nn.Linear(cfg.input_dim * cfg.downsample_ratio * cfg.downsample_ratio, cfg.n_embed * mlp_ratio)]
59
+ for _ in range(1, mlp_depth - 1):
60
+ modules.append(nn.GELU())
61
+ modules.append(nn.Linear(cfg.n_embed * mlp_ratio, cfg.n_embed * mlp_ratio))
62
+ modules.append(nn.GELU())
63
+ modules.append(nn.Linear(cfg.n_embed * mlp_ratio, cfg.n_embed))
64
+ modules = nn.Sequential(*modules)
65
+
66
+ elif cfg.projector_type == "low_high_hybrid_split_mlp_gelu":
67
+ mlp_depth = cfg.get("depth", 1)
68
+ self.high_up_proj = nn.Linear(cfg.input_dim, cfg.n_embed // 2)
69
+ self.low_up_proj = nn.Linear(cfg.input_dim, cfg.n_embed // 2)
70
+
71
+ modules = []
72
+ for _ in range(1, mlp_depth):
73
+ modules.append(nn.GELU())
74
+ modules.append(nn.Linear(cfg.n_embed, cfg.n_embed))
75
+ modules = nn.Sequential(*modules)
76
+
77
+ elif cfg.projector_type == "hybrid_split_feature_mlp_gelu":
78
+ mlp_depth = cfg.get("depth", 1)
79
+ channel_div = cfg.get("channel_div", 0.5)
80
+ self.high_up_proj = nn.Linear(cfg.input_dim[0], int(cfg.n_embed * channel_div))
81
+ self.low_up_proj = nn.Linear(cfg.input_dim[1], cfg.n_embed - int(cfg.n_embed * channel_div))
82
+
83
+ modules = []
84
+ for _ in range(1, mlp_depth):
85
+ modules.append(nn.GELU())
86
+ modules.append(nn.Linear(cfg.n_embed, cfg.n_embed))
87
+ modules = nn.Sequential(*modules)
88
+
89
+ elif cfg.projector_type == "low_high_split_mlp_gelu":
90
+ mlp_depth = cfg.get("depth", 1)
91
+ modules = []
92
+ for _ in range(1, mlp_depth):
93
+ modules.append(nn.GELU())
94
+ modules.append(nn.Linear(cfg.n_embed // 2, cfg.n_embed // 2))
95
+ modules = nn.Sequential(*modules)
96
+ self.high_layers = nn.Sequential(*modules)
97
+ self.low_layers = copy.deepcopy(modules)
98
+
99
+ else:
100
+ raise ValueError(f"Unknown projector type: {cfg.projector_type}")
101
+
102
+ if cfg.get("token_pooling", False):
103
+ self.token_pooling_layer = nn.Linear(cfg.input_dim * 4, cfg.input_dim)
104
+
105
+ if cfg.get("conv_fusion_high_low_features", False):
106
+ self.fusion_layer = nn.Linear(cfg.input_dim, cfg.input_dim)
107
+ self.layers = modules
108
+
109
+ def forward(self, x):
110
+ if self.cfg.get("token_pooling", False):
111
+ batch_size, wxh, channels = x.shape
112
+ w = h = int(wxh**0.5)
113
+ x = x.view(batch_size, w, h, channels)
114
+ x = x.permute(0, 3, 1, 2)
115
+ # import ipdb; ipdb.set_trace()
116
+ patches = x.unfold(2, 2, 2).unfold(3, 2, 2)
117
+ batch_size, channels, h_patches, w_patches, _, _ = patches.size()
118
+ # concatenate along the channel dimension
119
+ patches = patches.contiguous().view(batch_size, channels, h_patches * w_patches, -1)
120
+
121
+ # pass through the linear layer
122
+ patches = patches.permute(0, 2, 1, 3).contiguous()
123
+ patches = patches.view(batch_size, h_patches * w_patches, channels * 4)
124
+
125
+ x = self.token_pooling_layer(patches)
126
+
127
+ if self.cfg.get("conv_fusion_high_low_features", False):
128
+ x = self.fusion_layer(x[:, 0]) + x[:, 1]
129
+
130
+ if self.cfg.projector_type == 'low_high_hybrid_split_mlp_gelu':
131
+ high_x, low_x = x[0], x[1]
132
+ high_x = self.high_up_proj(high_x)
133
+ low_x = self.low_up_proj(low_x)
134
+ x = torch.concat([high_x, low_x], dim=-1)
135
+
136
+ if self.cfg.projector_type == 'hybrid_split_feature_mlp_gelu':
137
+ high_x = x[...,:self.cfg.input_dim[0]]
138
+ low_x = x[...,self.cfg.input_dim[0]:]
139
+ high_x = self.high_up_proj(high_x)
140
+ low_x = self.low_up_proj(low_x)
141
+ x = torch.concat([high_x, low_x], dim=-1)
142
+
143
+ if self.cfg.projector_type == 'low_high_split_mlp_gelu':
144
+ high_x, low_x = x[0], x[1]
145
+ high_x = self.high_layers(high_x)
146
+ low_x = self.low_layers(low_x)
147
+ x = torch.concat([high_x, low_x], dim=-1)
148
+ return x
149
+
150
+ if self.cfg.projector_type == 'downsample_mlp_gelu' or self.cfg.projector_type == 'normlayer_downsample_mlp_gelu':
151
+ bs, hw, input_dim = x.shape
152
+ h = w = int((hw) ** 0.5)
153
+
154
+ """compute padding"""
155
+ if h % self.cfg.downsample_ratio:
156
+ pad = self.cfg.downsample_ratio - h % self.cfg.downsample_ratio
157
+ else:
158
+ pad = 0
159
+ x = x.reshape(bs, h, w, input_dim)
160
+ if pad > 0:
161
+ x = F.pad(x, (0, 0, 0, pad, 0, pad), "constant", 0)
162
+
163
+ """4 to 1 concat"""
164
+ x = x.permute(0, 3, 1, 2) # B, C, H, W
165
+ x = F.unfold(x, kernel_size=self.cfg.downsample_ratio, stride=self.cfg.downsample_ratio, padding=0) # B, C*4, HW // 4
166
+ x = x.permute(0, 2, 1)
167
+
168
+ return self.layers(x)
169
+
170
+ @staticmethod
171
+ def get_flops_per_sample(cfg):
172
+ if cfg.projector_type == "linear":
173
+ fwd = 2 * cfg.input_dim * cfg.n_embed
174
+
175
+ elif "mlp_gelu" in cfg.projector_type :
176
+ mlp_depth = cfg.get("depth", 1)
177
+ downsample_ratio = cfg.get("downsample_ratio", 1)
178
+ input_dim = sum(cfg.input_dim) if isinstance(cfg.input_dim, list) else cfg.input_dim
179
+ input_dim = input_dim * downsample_ratio * downsample_ratio
180
+ fwd = 2 * input_dim * cfg.n_embed + (mlp_depth - 1) * 2 * cfg.n_embed * cfg.n_embed
181
+ else:
182
+ fwd = 0
183
+
184
+ return fwd * 3
185
+
186
+
187
+ #===================qwen2================================
188
+
189
+ class CustomQwen2Decoder(nn.Module):
190
+ """
191
+ Qwen2 decoder used as a visual encoder:
+ non-causal (bidirectional) attention over image tokens + causal attention over query tokens.
+ token_type_ids: 0 = non-causal (image), 1 = causal (query)
194
+ """
195
+
196
+ def __init__(
197
+ self,
198
+ decoder_layer: int = 24,
199
+ max_position_embeddings: int = 131072,
200
+ hidden_dimension: int = 896,
201
+ num_attention_heads: int = 14,
202
+ num_key_value_heads: int = 2,
203
+ intermediate_size: int = 4864,
204
+ vocab_size: int = 151936,
205
+ attn_implementation: str = "sdpa", # ⭐
206
+ rms_norm_eps: float = 1e-06,
207
+ rope_theta: float = 1000000.0,
208
+ attention_dropout: float = 0.0,
209
+ hidden_act: str = "silu",
210
+ initializer_range: float = 0.02,
211
+ ):
212
+ super().__init__()
213
+
214
+ # attn_implementation check
215
+ if attn_implementation == "flash_attention_2":
216
+ raise ValueError(
+ "CustomQwen2Decoder does not support flash_attention_2; "
+ "the custom attention mask requires 'sdpa' or 'eager'."
+ )
220
+
221
+ # load
222
+ Qwen2Model = getattr(transformers.models.qwen2.modeling_qwen2, 'Qwen2Model')
223
+ Qwen2Config = getattr(transformers, 'Qwen2Config')
224
+
225
+ # config
226
+ config = Qwen2Config(
227
+ hidden_size=hidden_dimension,
228
+ num_hidden_layers=decoder_layer,
229
+ num_attention_heads=num_attention_heads,
230
+ num_key_value_heads=num_key_value_heads,
231
+ intermediate_size=intermediate_size,
232
+ max_position_embeddings=max_position_embeddings,
233
+ vocab_size=vocab_size,
234
+ rms_norm_eps=rms_norm_eps,
235
+ rope_theta=rope_theta,
236
+ attention_dropout=attention_dropout,
237
+ hidden_act=hidden_act,
238
+ initializer_range=initializer_range,
239
+ _attn_implementation=attn_implementation, # ⭐
240
+ )
241
+
242
+ #
243
+ self.model = self._create_custom_model(Qwen2Model, config)
244
+
245
+ del self.model.embed_tokens
246
+
247
+ def _create_custom_model(self, Qwen2Model, config):
248
+ """ Qwen2Model """
249
+
250
+ class CustomQwen2ModelInner(Qwen2Model):
251
+
252
+
253
+ def forward(
254
+ self,
255
+ input_ids=None,
256
+ attention_mask=None,
257
+ position_ids=None,
258
+ past_key_values=None,
259
+ inputs_embeds=None,
260
+ token_type_ids=None, # ⭐
261
+ use_cache=None,
262
+ output_attentions=None,
263
+ output_hidden_states=None,
264
+ return_dict=None,
265
+ cache_position=None,
266
+ ):
267
+ # stash token_type_ids so the overridden _update_causal_mask can build the hybrid mask
268
+ self._current_token_type_ids = token_type_ids
269
+
270
+ outputs = super().forward(
271
+ input_ids=input_ids,
272
+ attention_mask=attention_mask,
273
+ position_ids=position_ids,
274
+ past_key_values=past_key_values,
275
+ inputs_embeds=inputs_embeds,
276
+ use_cache=use_cache,
277
+ output_attentions=output_attentions,
278
+ output_hidden_states=output_hidden_states,
279
+ return_dict=return_dict,
280
+ cache_position=cache_position,
281
+ )
282
+
283
+ return outputs
284
+
285
+ def _update_causal_mask(
286
+ self,
287
+ attention_mask,
288
+ input_tensor,
289
+ cache_position,
290
+ past_key_values,
291
+ output_attentions,
292
+ ):
293
+ dtype, device = input_tensor.dtype, input_tensor.device
294
+ min_dtype = torch.finfo(dtype).min
295
+ batch_size, sequence_length = input_tensor.shape[0], input_tensor.shape[1]
296
+
297
+ token_type_ids = self._current_token_type_ids
298
+
299
+ # attention mask
300
+ causal_mask = self._create_custom_4d_mask(
301
+ sequence_length=sequence_length,
302
+ dtype=dtype,
303
+ device=device,
304
+ batch_size=batch_size,
305
+ token_type_ids=token_type_ids,
306
+ )
307
+
308
+ # padding mask
309
+ if attention_mask is not None and attention_mask.dim() == 2:
310
+ padding_mask = attention_mask[:, None, None, :].to(dtype=dtype)
311
+ padding_mask = (1.0 - padding_mask) * min_dtype
312
+ causal_mask = causal_mask + padding_mask
313
+
314
+ return causal_mask
315
+
316
+ def _create_custom_4d_mask(
317
+ self,
318
+ sequence_length,
319
+ dtype,
320
+ device,
321
+ batch_size,
322
+ token_type_ids,
323
+ ):
324
+ min_dtype = torch.finfo(dtype).min
325
+
326
+ masks = []
327
+ for b in range(batch_size):
328
+ mask = torch.full(
329
+ (sequence_length, sequence_length),
330
+ fill_value=min_dtype,
331
+ dtype=dtype,
332
+ device=device
333
+ )
334
+
335
+ type_ids = token_type_ids[b]
336
+
337
+ image_positions = (type_ids == 0).nonzero(as_tuple=True)[0]
338
+ text_positions = (type_ids == 1).nonzero(as_tuple=True)[0]
339
+
340
+ # non-causal: every image token attends to all image tokens
341
+ if len(image_positions) > 0:
342
+ mask[image_positions[:, None], image_positions] = 0.0
343
+
344
+ # causal: each query/text token attends to all image tokens and to itself plus earlier query/text tokens
345
+ for i, text_pos in enumerate(text_positions):
346
+ if len(image_positions) > 0:
347
+ mask[text_pos, image_positions] = 0.0
348
+ mask[text_pos, text_positions[:i+1]] = 0.0
349
+
350
+ masks.append(mask)
351
+
352
+ mask = torch.stack(masks, dim=0).unsqueeze(1)
353
+ return mask
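+
+ # Illustrative example of the resulting hybrid mask (not executed). For
+ # token_type_ids = [0, 0, 1, 1] and m = torch.finfo(dtype).min, the 4x4 mask is:
+ #   [[0, 0, m, m],   image token 0: attends to both image tokens
+ #    [0, 0, m, m],   image token 1: attends to both image tokens
+ #    [0, 0, 0, m],   query token 0: attends to image tokens + itself
+ #    [0, 0, 0, 0]]   query token 1: attends to image tokens + query 0 + itself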
354
+
355
+ return CustomQwen2ModelInner(config)
356
+
357
+ def forward(
358
+ self,
359
+ inputs_embeds,
360
+ token_type_ids,
361
+ attention_mask=None,
362
+ **kwargs
363
+ ):
364
+ """
365
+ Args:
366
+ inputs_embeds: [batch_size, seq_len, hidden_dim]
367
+ token_type_ids: [batch_size, seq_len], 0=non-causal, 1=causal
368
+ attention_mask: [batch_size, seq_len], optional
369
+ """
370
+ return self.model(
371
+ inputs_embeds=inputs_embeds,
372
+ token_type_ids=token_type_ids,
373
+ attention_mask=attention_mask,
374
+ **kwargs
375
+ )
376
+
377
+
378
+
379
+
380
+
381
+ # Commented-out smoke test for the hybrid mask (requires a CUDA device; uncomment to run):
+ # import time
+ # decoder_sdpa = CustomQwen2Decoder(attn_implementation="sdpa").cuda().eval()
+ # batch_size = 2
+ # inputs_embeds = torch.randn(batch_size, 512, 896).cuda()
+ # token_type_ids = torch.cat([
+ #     torch.zeros(batch_size, 256, dtype=torch.long),
+ #     torch.ones(batch_size, 256, dtype=torch.long),
+ # ], dim=1).cuda()
+ # start = time.time()
+ # with torch.no_grad():
+ #     outputs_sdpa = decoder_sdpa(inputs_embeds, token_type_ids)
+ # print(outputs_sdpa[0].shape)
+ # print(f"SDPA time: {time.time() - start:.4f}s")
395
+
396
+
397
+
398
+ class Qwen2Decoder2Encoder(nn.Module):
399
+ """
400
+ Wraps CustomQwen2Decoder so the Qwen2 decoder acts as a visual encoder:
+ flattened image features are concatenated with learned query embeddings (144 or 256 queries),
+ and only the query positions of the decoder output are returned.
403
+ """
404
+
405
+ def __init__(
406
+ self,
407
+ decoder_layer: int,
408
+ hidden_dimension: int,
409
+ num_attention_heads: int,
410
+ num_key_value_heads: int,
411
+ intermediate_size: int,
412
+ max_query: int,
413
+ ):
414
+ super().__init__()
415
+
416
+ self.model = CustomQwen2Decoder(
417
+ decoder_layer=decoder_layer,
418
+ hidden_dimension=hidden_dimension,
419
+ num_attention_heads=num_attention_heads,
420
+ num_key_value_heads=num_key_value_heads,
421
+ intermediate_size=intermediate_size,
422
+ attn_implementation="sdpa",
423
+ )
424
+
425
+
426
+
427
+
428
+ self.query_768 = nn.Embedding(144, hidden_dimension)
429
+ self.query_1024 = nn.Embedding(256, hidden_dimension)
430
+
431
+
432
+ # self.query_refixation = nn.Embedding(int(math.sqrt(max_query)), hidden_dimension)
433
+
434
+
435
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
436
+ x = x.flatten(2).transpose(1, 2)
437
+
438
+ bs, n_query, _ = x.shape
439
+
440
+ if n_query == 144:
441
+ param_img = self.query_768.weight
442
+ elif n_query == 256:
443
+ param_img = self.query_1024.weight
+ else:
+ raise ValueError(f"unexpected number of visual tokens: {n_query}; expected 144 or 256")
444
+
445
+ batch_query_imgs = param_img.unsqueeze(0).expand(
446
+ bs, -1, -1
447
+ ) # (batch_size, num_queries, hidden_size)
448
+
449
+
450
+
451
+ x_combined = torch.cat([x, batch_query_imgs], dim=1)
452
+
453
+ token_type_ids = torch.cat([
454
+ torch.zeros(bs, n_query, dtype=torch.long),
455
+ torch.ones(bs, n_query, dtype=torch.long),
456
+ ], dim=1)
457
+
458
+
459
+ y = self.model(x_combined, token_type_ids)[0]
460
+
461
+
462
+ y = y[:, n_query:, :] # causal flow query
463
+
464
+
465
+ return y
466
+
467
+
468
+ def build_qwen2_decoder_as_encoder(
469
+ decoder_layer=24,
470
+ hidden_dimension=896,
471
+ num_attention_heads=14,
472
+ num_key_value_heads=2,
473
+ intermediate_size=4864,
474
+ max_query = 400,
475
+ checkpoint=None,
476
+ ):
477
+
478
+ decoder_as_encoder = Qwen2Decoder2Encoder(
479
+ decoder_layer=decoder_layer,
480
+ hidden_dimension = hidden_dimension,
481
+ num_attention_heads = num_attention_heads,
482
+ num_key_value_heads = num_key_value_heads,
483
+ intermediate_size = intermediate_size,
484
+ max_query = max_query
485
+ )
486
+
487
+
488
+
489
+
490
+ if checkpoint is not None:
491
+ # with open(checkpoint, "rb") as f:
492
+ state_dict = torch.load(checkpoint)
493
+
494
+ decoder_as_encoder.load_state_dict(state_dict, strict=True)
495
+ # tob
496
+ print(checkpoint)
497
+ return decoder_as_encoder
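+
+ # Illustrative usage sketch (not executed; shapes follow the defaults above):
+ # enc = build_qwen2_decoder_as_encoder()   # hidden_dimension=896
+ # feats = torch.randn(1, 896, 12, 12)      # [B, C, H, W] feature map from the vision tower
+ # out = enc(feats)                         # 12*12 = 144 tokens -> query_768 branch
+ # print(out.shape)                         # torch.Size([1, 144, 896])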
498
+
499
+
500
+
501
+
502
+ #=========================Sam-Vary=================================
503
+
504
+
505
+ def get_abs_pos_sam(abs_pos, tgt_size):
506
+
507
+ dtype = abs_pos.dtype
508
+
509
+ src_size = abs_pos.size(1)
510
+
511
+ if src_size != tgt_size:
512
+ old_pos_embed = abs_pos.permute(0, 3, 1, 2)
513
+ old_pos_embed = old_pos_embed.to(torch.float32)
514
+ new_pos_embed = F.interpolate(
515
+ old_pos_embed,
516
+ size=(tgt_size, tgt_size),
517
+ mode='bicubic',
518
+ antialias=True,
519
+ align_corners=False,
520
+ ).to(dtype)
521
+ new_pos_embed = new_pos_embed.permute(0, 2, 3, 1)
522
+ return new_pos_embed
523
+ else:
524
+ return abs_pos
525
+
526
+
527
+
528
+
529
+ class MLPBlock(nn.Module):
530
+ def __init__(
531
+ self,
532
+ embedding_dim: int,
533
+ mlp_dim: int,
534
+ act: Type[nn.Module] = nn.GELU,
535
+ ) -> None:
536
+ super().__init__()
537
+ self.lin1 = nn.Linear(embedding_dim, mlp_dim)
538
+ self.lin2 = nn.Linear(mlp_dim, embedding_dim)
539
+ self.act = act()
540
+
541
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
542
+ return self.lin2(self.act(self.lin1(x)))
543
+
544
+
545
+ # From https://github.com/facebookresearch/detectron2/blob/main/detectron2/layers/batch_norm.py # noqa
546
+ # Itself from https://github.com/facebookresearch/ConvNeXt/blob/d1fa8f6fef0a165b27399986cc2bdacc92777e40/models/convnext.py#L119 # noqa
547
+ class LayerNorm2d(nn.Module):
548
+ def __init__(self, num_channels: int, eps: float = 1e-6) -> None:
549
+ super().__init__()
550
+ self.weight = nn.Parameter(torch.ones(num_channels))
551
+ self.bias = nn.Parameter(torch.zeros(num_channels))
552
+ self.eps = eps
553
+
554
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
555
+ u = x.mean(1, keepdim=True)
556
+ s = (x - u).pow(2).mean(1, keepdim=True)
557
+ x = (x - u) / torch.sqrt(s + self.eps)
558
+ x = self.weight[:, None, None] * x + self.bias[:, None, None]
559
+ return x
560
+
561
+
562
+ # This class and its supporting functions below lightly adapted from the ViTDet backbone available at: https://github.com/facebookresearch/detectron2/blob/main/detectron2/modeling/backbone/vit.py # noqa
563
+ class ImageEncoderViT(nn.Module):
564
+ def __init__(
565
+ self,
566
+ img_size: int = 1024,
567
+ patch_size: int = 16,
568
+ in_chans: int = 3,
569
+ embed_dim: int = 768,
570
+ depth: int = 12,
571
+ num_heads: int = 12,
572
+ mlp_ratio: float = 4.0,
573
+ out_chans: int = 256,
574
+ qkv_bias: bool = True,
575
+ norm_layer: Type[nn.Module] = nn.LayerNorm,
576
+ act_layer: Type[nn.Module] = nn.GELU,
577
+ use_abs_pos: bool = True,
578
+ use_rel_pos: bool = False,
579
+ rel_pos_zero_init: bool = True,
580
+ window_size: int = 0,
581
+ global_attn_indexes: Tuple[int, ...] = (),
582
+ ) -> None:
583
+ """
584
+ Args:
585
+ img_size (int): Input image size.
586
+ patch_size (int): Patch size.
587
+ in_chans (int): Number of input image channels.
588
+ embed_dim (int): Patch embedding dimension.
589
+ depth (int): Depth of ViT.
590
+ num_heads (int): Number of attention heads in each ViT block.
591
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
592
+ qkv_bias (bool): If True, add a learnable bias to query, key, value.
593
+ norm_layer (nn.Module): Normalization layer.
594
+ act_layer (nn.Module): Activation layer.
595
+ use_abs_pos (bool): If True, use absolute positional embeddings.
596
+ use_rel_pos (bool): If True, add relative positional embeddings to the attention map.
597
+ rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.
598
+ window_size (int): Window size for window attention blocks.
599
+ global_attn_indexes (list): Indexes for blocks using global attention.
600
+ """
601
+ super().__init__()
602
+ self.img_size = img_size
603
+
604
+ self.patch_embed = PatchEmbed(
605
+ kernel_size=(patch_size, patch_size),
606
+ stride=(patch_size, patch_size),
607
+ in_chans=in_chans,
608
+ embed_dim=embed_dim,
609
+ )
610
+
611
+ self.pos_embed: Optional[nn.Parameter] = None
612
+ if use_abs_pos:
613
+ # Initialize absolute positional embedding with pretrain image size.
614
+ self.pos_embed = nn.Parameter(
615
+ torch.zeros(1, img_size // patch_size, img_size // patch_size, embed_dim)
616
+ )
617
+
618
+ self.blocks = nn.ModuleList()
619
+ for i in range(depth):
620
+ block = Block(
621
+ dim=embed_dim,
622
+ num_heads=num_heads,
623
+ mlp_ratio=mlp_ratio,
624
+ qkv_bias=qkv_bias,
625
+ norm_layer=norm_layer,
626
+ act_layer=act_layer,
627
+ use_rel_pos=use_rel_pos,
628
+ rel_pos_zero_init=rel_pos_zero_init,
629
+ window_size=window_size if i not in global_attn_indexes else 0,
630
+ input_size=(img_size // patch_size, img_size // patch_size),
631
+ )
632
+ self.blocks.append(block)
633
+
634
+ self.neck = nn.Sequential(
635
+ nn.Conv2d(
636
+ embed_dim,
637
+ out_chans,
638
+ kernel_size=1,
639
+ bias=False,
640
+ ),
641
+ LayerNorm2d(out_chans),
642
+ nn.Conv2d(
643
+ out_chans,
644
+ out_chans,
645
+ kernel_size=3,
646
+ padding=1,
647
+ bias=False,
648
+ ),
649
+ LayerNorm2d(out_chans),
650
+ )
651
+
652
+ self.net_2 = nn.Conv2d(256, 512, kernel_size=3, stride=2, padding=1, bias=False)
653
+ self.net_3 = nn.Conv2d(512, 896, kernel_size=3, stride=2, padding=1, bias=False)
654
+
655
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
656
+ x = self.patch_embed(x)
657
+ if self.pos_embed is not None:
658
+ # x = x + self.pos_embed
659
+ x = x + get_abs_pos_sam(self.pos_embed, x.size(1))
660
+
661
+ for blk in self.blocks:
662
+ x = blk(x)
663
+
664
+ x = self.neck(x.permute(0, 3, 1, 2))
665
+ x2 = self.net_2(x)
666
+ x3 = self.net_3(x2.clone())
667
+
668
+ return x3
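+
+ # Shape sketch for the SAM-B defaults used below (patch_size=16):
+ #   1024x1024 input -> patch_embed [B, 64, 64, 768] -> neck [B, 256, 64, 64]
+ #     -> net_2 (stride 2) [B, 512, 32, 32] -> net_3 (stride 2) [B, 896, 16, 16]  (16*16 = 256 tokens)
+ #   768x768 input  -> [B, 48, 48, 768] -> ... -> [B, 896, 12, 12]                (12*12 = 144 tokens)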
669
+
670
+
671
+ class Block(nn.Module):
672
+ """Transformer blocks with support of window attention and residual propagation blocks"""
673
+
674
+ def __init__(
675
+ self,
676
+ dim: int,
677
+ num_heads: int,
678
+ mlp_ratio: float = 4.0,
679
+ qkv_bias: bool = True,
680
+ norm_layer: Type[nn.Module] = nn.LayerNorm,
681
+ act_layer: Type[nn.Module] = nn.GELU,
682
+ use_rel_pos: bool = False,
683
+ rel_pos_zero_init: bool = True,
684
+ window_size: int = 0,
685
+ input_size: Optional[Tuple[int, int]] = None,
686
+ ) -> None:
687
+ """
688
+ Args:
689
+ dim (int): Number of input channels.
690
+ num_heads (int): Number of attention heads in each ViT block.
691
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
692
+ qkv_bias (bool): If True, add a learnable bias to query, key, value.
693
+ norm_layer (nn.Module): Normalization layer.
694
+ act_layer (nn.Module): Activation layer.
695
+ use_rel_pos (bool): If True, add relative positional embeddings to the attention map.
696
+ rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.
697
+ window_size (int): Window size for window attention blocks. If it equals 0, then
698
+ use global attention.
699
+ input_size (tuple(int, int) or None): Input resolution for calculating the relative
700
+ positional parameter size.
701
+ """
702
+ super().__init__()
703
+ self.norm1 = norm_layer(dim)
704
+ self.attn = Attention(
705
+ dim,
706
+ num_heads=num_heads,
707
+ qkv_bias=qkv_bias,
708
+ use_rel_pos=use_rel_pos,
709
+ rel_pos_zero_init=rel_pos_zero_init,
710
+ input_size=input_size if window_size == 0 else (window_size, window_size),
711
+ )
712
+
713
+ self.norm2 = norm_layer(dim)
714
+ self.mlp = MLPBlock(embedding_dim=dim, mlp_dim=int(dim * mlp_ratio), act=act_layer)
715
+
716
+ self.window_size = window_size
717
+
718
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
719
+ shortcut = x
720
+ x = self.norm1(x)
721
+ # Window partition
722
+ if self.window_size > 0:
723
+ H, W = x.shape[1], x.shape[2]
724
+ x, pad_hw = window_partition(x, self.window_size)
725
+
726
+ x = self.attn(x)
727
+ # Reverse window partition
728
+ if self.window_size > 0:
729
+ x = window_unpartition(x, self.window_size, pad_hw, (H, W))
730
+
731
+ x = shortcut + x
732
+ x = x + self.mlp(self.norm2(x))
733
+
734
+ return x
735
+
736
+
737
+ class Attention(nn.Module):
738
+ """Multi-head Attention block with relative position embeddings."""
739
+
740
+ def __init__(
741
+ self,
742
+ dim: int,
743
+ num_heads: int = 8,
744
+ qkv_bias: bool = True,
745
+ use_rel_pos: bool = False,
746
+ rel_pos_zero_init: bool = True,
747
+ input_size: Optional[Tuple[int, int]] = None,
748
+ ) -> None:
749
+ """
750
+ Args:
751
+ dim (int): Number of input channels.
752
+ num_heads (int): Number of attention heads.
753
+ qkv_bias (bool): If True, add a learnable bias to query, key, value.
754
+ rel_pos (bool): If True, add relative positional embeddings to the attention map.
755
+ rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.
756
+ input_size (tuple(int, int) or None): Input resolution for calculating the relative
757
+ positional parameter size.
758
+ """
759
+ super().__init__()
760
+ self.num_heads = num_heads
761
+ head_dim = dim // num_heads
762
+ self.scale = head_dim**-0.5
763
+
764
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
765
+ self.proj = nn.Linear(dim, dim)
766
+
767
+ self.use_rel_pos = use_rel_pos
768
+ if self.use_rel_pos:
769
+ assert (
770
+ input_size is not None
771
+ ), "Input size must be provided if using relative positional encoding."
772
+ # initialize relative positional embeddings
773
+ self.rel_pos_h = nn.Parameter(torch.zeros(2 * input_size[0] - 1, head_dim))
774
+ self.rel_pos_w = nn.Parameter(torch.zeros(2 * input_size[1] - 1, head_dim))
775
+
776
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
777
+ B, H, W, _ = x.shape
778
+ # qkv with shape (3, B, nHead, H * W, C)
779
+ qkv = self.qkv(x).reshape(B, H * W, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
780
+ # q, k, v with shape (B * nHead, H * W, C)
781
+ q, k, v = qkv.reshape(3, B * self.num_heads, H * W, -1).unbind(0)
782
+
783
+ rel_h, rel_w = None, None
784
+ if self.use_rel_pos:
785
+ rel_h, rel_w = add_decomposed_rel_pos(q, self.rel_pos_h, self.rel_pos_w, (H, W), (H, W))
786
+
787
+ q = q.view(B, self.num_heads, H * W, -1)
788
+ k = k.view(B, self.num_heads, H * W, -1)
789
+ v = v.view(B, self.num_heads, H * W, -1)
790
+
791
+ if self.use_rel_pos:
792
+ rel_h = rel_h.view(B, self.num_heads, rel_h.size(1), rel_h.size(2), rel_h.size(3))
793
+ rel_w = rel_w.view(B, self.num_heads, rel_w.size(1), rel_w.size(2), rel_w.size(3))
794
+ attn_bias = (rel_h + rel_w).view(B, self.num_heads, rel_h.size(2), rel_h.size(3) * rel_w.size(4))
795
+ x = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=attn_bias)
796
+ # x = _attention_rel_h_rel_w(q, k, v, rel_h, rel_w)
797
+ else:
798
+ x = torch.nn.functional.scaled_dot_product_attention(q, k, v)
799
+
800
+ x = x.view(B, self.num_heads, H, W, -1).permute(0, 2, 3, 1, 4).reshape(B, H, W, -1)
801
+
802
+ x = self.proj(x)
803
+
804
+ return x
805
+
806
+
807
+ def window_partition(x: torch.Tensor, window_size: int) -> Tuple[torch.Tensor, Tuple[int, int]]:
808
+ """
809
+ Partition into non-overlapping windows with padding if needed.
810
+ Args:
811
+ x (tensor): input tokens with [B, H, W, C].
812
+ window_size (int): window size.
813
+
814
+ Returns:
815
+ windows: windows after partition with [B * num_windows, window_size, window_size, C].
816
+ (Hp, Wp): padded height and width before partition
817
+ """
818
+ B, H, W, C = x.shape
819
+
820
+ pad_h = (window_size - H % window_size) % window_size
821
+ pad_w = (window_size - W % window_size) % window_size
822
+ if pad_h > 0 or pad_w > 0:
823
+ x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h))
824
+ Hp, Wp = H + pad_h, W + pad_w
825
+
826
+ x = x.view(B, Hp // window_size, window_size, Wp // window_size, window_size, C)
827
+ windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
828
+ return windows, (Hp, Wp)
829
+
830
+
831
+ def window_unpartition(
832
+ windows: torch.Tensor, window_size: int, pad_hw: Tuple[int, int], hw: Tuple[int, int]
833
+ ) -> torch.Tensor:
834
+ """
835
+ Window unpartition into original sequences and removing padding.
836
+ Args:
837
+ windows (tensor): input tokens with [B * num_windows, window_size, window_size, C].
838
+ window_size (int): window size.
839
+ pad_hw (Tuple): padded height and width (Hp, Wp).
840
+ hw (Tuple): original height and width (H, W) before padding.
841
+
842
+ Returns:
843
+ x: unpartitioned sequences with [B, H, W, C].
844
+ """
845
+ Hp, Wp = pad_hw
846
+ H, W = hw
847
+ B = windows.shape[0] // (Hp * Wp // window_size // window_size)
848
+ x = windows.view(B, Hp // window_size, Wp // window_size, window_size, window_size, -1)
849
+ x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, Hp, Wp, -1)
850
+
851
+ if Hp > H or Wp > W:
852
+ x = x[:, :H, :W, :].contiguous()
853
+ return x
854
+
855
+
856
+ def get_rel_pos(q_size: int, k_size: int, rel_pos: torch.Tensor) -> torch.Tensor:
857
+ """
858
+ Get relative positional embeddings according to the relative positions of
859
+ query and key sizes.
860
+ Args:
861
+ q_size (int): size of query q.
862
+ k_size (int): size of key k.
863
+ rel_pos (Tensor): relative position embeddings (L, C).
864
+
865
+ Returns:
866
+ Extracted positional embeddings according to relative positions.
867
+ """
868
+ max_rel_dist = int(2 * max(q_size, k_size) - 1)
869
+ # Interpolate rel pos if needed.
870
+ if rel_pos.shape[0] != max_rel_dist:
871
+ # Interpolate rel pos.
872
+ dtype = rel_pos.dtype
873
+ rel_pos = rel_pos.to(torch.float32)
874
+ rel_pos_resized = F.interpolate(
875
+ rel_pos.reshape(1, rel_pos.shape[0], -1).permute(0, 2, 1),
876
+ size=max_rel_dist,
877
+ mode="linear",
878
+ ).to(dtype)
879
+ rel_pos_resized = rel_pos_resized.reshape(-1, max_rel_dist).permute(1, 0)
880
+ else:
881
+ rel_pos_resized = rel_pos
882
+
883
+ # Scale the coords with short length if shapes for q and k are different.
884
+ q_coords = torch.arange(q_size, device=rel_pos.device)[:, None] * max(k_size / q_size, 1.0)
885
+ k_coords = torch.arange(k_size, device=rel_pos.device)[None, :] * max(q_size / k_size, 1.0)
886
+ relative_coords = (q_coords - k_coords) + (k_size - 1) * max(q_size / k_size, 1.0)
887
+
888
+ return rel_pos_resized[relative_coords.long()]
889
+
890
+
891
+ def add_decomposed_rel_pos(
892
+ q: torch.Tensor,
893
+ rel_pos_h: torch.Tensor,
894
+ rel_pos_w: torch.Tensor,
895
+ q_size: Tuple[int, int],
896
+ k_size: Tuple[int, int],
897
+ ) -> torch.Tensor:
898
+ """
899
+ Calculate decomposed Relative Positional Embeddings from :paper:`mvitv2`.
900
+ https://github.com/facebookresearch/mvit/blob/19786631e330df9f3622e5402b4a419a263a2c80/mvit/models/attention.py # noqa B950
901
+ Args:
902
+ q (Tensor): query q in the attention layer with shape (B, q_h * q_w, C).
903
+ rel_pos_h (Tensor): relative position embeddings (Lh, C) for height axis.
904
+ rel_pos_w (Tensor): relative position embeddings (Lw, C) for width axis.
905
+ q_size (Tuple): spatial sequence size of query q with (q_h, q_w).
906
+ k_size (Tuple): spatial sequence size of key k with (k_h, k_w).
907
+
908
+ Returns:
909
+ attn (Tensor): attention map with added relative positional embeddings.
910
+ """
911
+ q_h, q_w = q_size
912
+ k_h, k_w = k_size
913
+ Rh = get_rel_pos(q_h, k_h, rel_pos_h)
914
+ Rw = get_rel_pos(q_w, k_w, rel_pos_w)
915
+
916
+ B, _, dim = q.shape
917
+ r_q = q.reshape(B, q_h, q_w, dim)
918
+ rel_h = torch.einsum("bhwc,hkc->bhwk", r_q, Rh)
919
+ rel_w = torch.einsum("bhwc,wkc->bhwk", r_q, Rw)
920
+ rel_h = rel_h.unsqueeze(-1)
921
+ rel_w = rel_w.unsqueeze(-2)
922
+ rel_h = rel_h.reshape(B, q_h * q_w, k_h, 1)
923
+ rel_w = rel_w.reshape(B, q_h * q_w, 1, k_w)
924
+
925
+ return rel_h, rel_w
926
+
927
+
928
+ class PatchEmbed(nn.Module):
929
+ """
930
+ Image to Patch Embedding.
931
+ """
932
+
933
+ def __init__(
934
+ self,
935
+ kernel_size: Tuple[int, int] = (16, 16),
936
+ stride: Tuple[int, int] = (16, 16),
937
+ padding: Tuple[int, int] = (0, 0),
938
+ in_chans: int = 3,
939
+ embed_dim: int = 768,
940
+ ) -> None:
941
+ """
942
+ Args:
943
+ kernel_size (Tuple): kernel size of the projection layer.
944
+ stride (Tuple): stride of the projection layer.
945
+ padding (Tuple): padding size of the projection layer.
946
+ in_chans (int): Number of input image channels.
947
+ embed_dim (int): Patch embedding dimension.
948
+ """
949
+ super().__init__()
950
+
951
+ self.proj = nn.Conv2d(
952
+ in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding
953
+ )
954
+
955
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
956
+ x = self.proj(x)
957
+ # B C H W -> B H W C
958
+ x = x.permute(0, 2, 3, 1)
959
+ return x
960
+
961
+
962
+ def build_sam_vit_b(checkpoint=None):
963
+ return _build_sam(
964
+ encoder_embed_dim=768,
965
+ encoder_depth=12,
966
+ encoder_num_heads=12,
967
+ encoder_global_attn_indexes=[2, 5, 8, 11],
968
+ checkpoint=checkpoint,
969
+ )
970
+
971
+ def build_sam_fast_vit_b(checkpoint=None, compile_mode='max-autotune', dtype=torch.bfloat16):
972
+ image_encoder = build_sam_vit_b(checkpoint).eval().to(dtype)
973
+ # sam = _apply_eval_dtype_sam(sam, dtype)
974
+ image_encoder = torch.compile(image_encoder, mode=compile_mode)
975
+ return image_encoder
976
+
977
+
978
+ def _build_sam(
979
+ encoder_embed_dim,
980
+ encoder_depth,
981
+ encoder_num_heads,
982
+ encoder_global_attn_indexes,
983
+ checkpoint=None,
984
+ ):
985
+ prompt_embed_dim = 256
986
+ image_size = 1024
987
+ vit_patch_size = 16
988
+ image_embedding_size = image_size // vit_patch_size
989
+ image_encoder=ImageEncoderViT(
990
+ depth=encoder_depth,
991
+ embed_dim=encoder_embed_dim,
992
+ img_size=image_size,
993
+ mlp_ratio=4,
994
+ norm_layer=partial(torch.nn.LayerNorm, eps=1e-6),
995
+ num_heads=encoder_num_heads,
996
+ patch_size=vit_patch_size,
997
+ qkv_bias=True,
998
+ use_rel_pos=True,
999
+ global_attn_indexes=encoder_global_attn_indexes,
1000
+ window_size=14,
1001
+ out_chans=prompt_embed_dim,
1002
+ )
1003
+ image_encoder.eval()
1004
+ if checkpoint is not None:
1005
+ # with open(checkpoint, "rb") as f:
1006
+ state_dict = torch.load(checkpoint)
1007
+ # print(state_dict.keys())
1008
+ # for key in state_dict:
1009
+ # image_encoder.load_state_dict({k[14:]: v for k, v in state_dict.items() if 'image_encoder' in k}, strict=False)
1010
+ # ocr-anyting
1011
+ # image_encoder.load_state_dict(state_dict, strict=True)
1012
+ # tob
1013
+ image_encoder.load_state_dict({k[30:]: v for k, v in state_dict.items() if 'vision_tower_high' in k}, strict=True)
1014
+ print(checkpoint)
1015
+ return image_encoder
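+
+ # Illustrative composition sketch (an assumption about how these pieces fit together, not executed):
+ # sam = build_sam_vit_b()                          # SAM-ViT-B tower ending in net_2/net_3
+ # dec = build_qwen2_decoder_as_encoder()           # Qwen2 decoder-as-encoder with 144/256 queries
+ # tokens = dec(sam(torch.randn(1, 3, 768, 768)))   # -> [1, 144, 896], ready for MlpProjector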
LICENSE.txt ADDED
@@ -0,0 +1,202 @@
1
+
2
+ Apache License
3
+ Version 2.0, January 2004
4
+ http://www.apache.org/licenses/
5
+
6
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
7
+
8
+ 1. Definitions.
9
+
10
+ "License" shall mean the terms and conditions for use, reproduction,
11
+ and distribution as defined by Sections 1 through 9 of this document.
12
+
13
+ "Licensor" shall mean the copyright owner or entity authorized by
14
+ the copyright owner that is granting the License.
15
+
16
+ "Legal Entity" shall mean the union of the acting entity and all
17
+ other entities that control, are controlled by, or are under common
18
+ control with that entity. For the purposes of this definition,
19
+ "control" means (i) the power, direct or indirect, to cause the
20
+ direction or management of such entity, whether by contract or
21
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
22
+ outstanding shares, or (iii) beneficial ownership of such entity.
23
+
24
+ "You" (or "Your") shall mean an individual or Legal Entity
25
+ exercising permissions granted by this License.
26
+
27
+ "Source" form shall mean the preferred form for making modifications,
28
+ including but not limited to software source code, documentation
29
+ source, and configuration files.
30
+
31
+ "Object" form shall mean any form resulting from mechanical
32
+ transformation or translation of a Source form, including but
33
+ not limited to compiled object code, generated documentation,
34
+ and conversions to other media types.
35
+
36
+ "Work" shall mean the work of authorship, whether in Source or
37
+ Object form, made available under the License, as indicated by a
38
+ copyright notice that is included in or attached to the work
39
+ (an example is provided in the Appendix below).
40
+
41
+ "Derivative Works" shall mean any work, whether in Source or Object
42
+ form, that is based on (or derived from) the Work and for which the
43
+ editorial revisions, annotations, elaborations, or other modifications
44
+ represent, as a whole, an original work of authorship. For the purposes
45
+ of this License, Derivative Works shall not include works that remain
46
+ separable from, or merely link (or bind by name) to the interfaces of,
47
+ the Work and Derivative Works thereof.
48
+
49
+ "Contribution" shall mean any work of authorship, including
50
+ the original version of the Work and any modifications or additions
51
+ to that Work or Derivative Works thereof, that is intentionally
52
+ submitted to Licensor for inclusion in the Work by the copyright owner
53
+ or by an individual or Legal Entity authorized to submit on behalf of
54
+ the copyright owner. For the purposes of this definition, "submitted"
55
+ means any form of electronic, verbal, or written communication sent
56
+ to the Licensor or its representatives, including but not limited to
57
+ communication on electronic mailing lists, source code control systems,
58
+ and issue tracking systems that are managed by, or on behalf of, the
59
+ Licensor for the purpose of discussing and improving the Work, but
60
+ excluding communication that is conspicuously marked or otherwise
61
+ designated in writing by the copyright owner as "Not a Contribution."
62
+
63
+ "Contributor" shall mean Licensor and any individual or Legal Entity
64
+ on behalf of whom a Contribution has been received by Licensor and
65
+ subsequently incorporated within the Work.
66
+
67
+ 2. Grant of Copyright License. Subject to the terms and conditions of
68
+ this License, each Contributor hereby grants to You a perpetual,
69
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
70
+ copyright license to reproduce, prepare Derivative Works of,
71
+ publicly display, publicly perform, sublicense, and distribute the
72
+ Work and such Derivative Works in Source or Object form.
73
+
74
+ 3. Grant of Patent License. Subject to the terms and conditions of
75
+ this License, each Contributor hereby grants to You a perpetual,
76
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
77
+ (except as stated in this section) patent license to make, have made,
78
+ use, offer to sell, sell, import, and otherwise transfer the Work,
79
+ where such license applies only to those patent claims licensable
80
+ by such Contributor that are necessarily infringed by their
81
+ Contribution(s) alone or by combination of their Contribution(s)
82
+ with the Work to which such Contribution(s) was submitted. If You
83
+ institute patent litigation against any entity (including a
84
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
85
+ or a Contribution incorporated within the Work constitutes direct
86
+ or contributory patent infringement, then any patent licenses
87
+ granted to You under this License for that Work shall terminate
88
+ as of the date such litigation is filed.
89
+
90
+ 4. Redistribution. You may reproduce and distribute copies of the
91
+ Work or Derivative Works thereof in any medium, with or without
92
+ modifications, and in Source or Object form, provided that You
93
+ meet the following conditions:
94
+
95
+ (a) You must give any other recipients of the Work or
96
+ Derivative Works a copy of this License; and
97
+
98
+ (b) You must cause any modified files to carry prominent notices
99
+ stating that You changed the files; and
100
+
101
+ (c) You must retain, in the Source form of any Derivative Works
102
+ that You distribute, all copyright, patent, trademark, and
103
+ attribution notices from the Source form of the Work,
104
+ excluding those notices that do not pertain to any part of
105
+ the Derivative Works; and
106
+
107
+ (d) If the Work includes a "NOTICE" text file as part of its
108
+ distribution, then any Derivative Works that You distribute must
109
+ include a readable copy of the attribution notices contained
110
+ within such NOTICE file, excluding those notices that do not
111
+ pertain to any part of the Derivative Works, in at least one
112
+ of the following places: within a NOTICE text file distributed
113
+ as part of the Derivative Works; within the Source form or
114
+ documentation, if provided along with the Derivative Works; or,
115
+ within a display generated by the Derivative Works, if and
116
+ wherever such third-party notices normally appear. The contents
117
+ of the NOTICE file are for informational purposes only and
118
+ do not modify the License. You may add Your own attribution
119
+ notices within Derivative Works that You distribute, alongside
120
+ or as an addendum to the NOTICE text from the Work, provided
121
+ that such additional attribution notices cannot be construed
122
+ as modifying the License.
123
+
124
+ You may add Your own copyright statement to Your modifications and
125
+ may provide additional or different license terms and conditions
126
+ for use, reproduction, or distribution of Your modifications, or
127
+ for any such Derivative Works as a whole, provided Your use,
128
+ reproduction, and distribution of the Work otherwise complies with
129
+ the conditions stated in this License.
130
+
131
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
132
+ any Contribution intentionally submitted for inclusion in the Work
133
+ by You to the Licensor shall be under the terms and conditions of
134
+ this License, without any additional terms or conditions.
135
+ Notwithstanding the above, nothing herein shall supersede or modify
136
+ the terms of any separate license agreement you may have executed
137
+ with Licensor regarding such Contributions.
138
+
139
+ 6. Trademarks. This License does not grant permission to use the trade
140
+ names, trademarks, service marks, or product names of the Licensor,
141
+ except as required for reasonable and customary use in describing the
142
+ origin of the Work and reproducing the content of the NOTICE file.
143
+
144
+ 7. Disclaimer of Warranty. Unless required by applicable law or
145
+ agreed to in writing, Licensor provides the Work (and each
146
+ Contributor provides its Contributions) on an "AS IS" BASIS,
147
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
148
+ implied, including, without limitation, any warranties or conditions
149
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
150
+ PARTICULAR PURPOSE. You are solely responsible for determining the
151
+ appropriateness of using or redistributing the Work and assume any
152
+ risks associated with Your exercise of permissions under this License.
153
+
154
+ 8. Limitation of Liability. In no event and under no legal theory,
155
+ whether in tort (including negligence), contract, or otherwise,
156
+ unless required by applicable law (such as deliberate and grossly
157
+ negligent acts) or agreed to in writing, shall any Contributor be
158
+ liable to You for damages, including any direct, indirect, special,
159
+ incidental, or consequential damages of any character arising as a
160
+ result of this License or out of the use or inability to use the
161
+ Work (including but not limited to damages for loss of goodwill,
162
+ work stoppage, computer failure or malfunction, or any and all
163
+ other commercial damages or losses), even if such Contributor
164
+ has been advised of the possibility of such damages.
165
+
166
+ 9. Accepting Warranty or Additional Liability. While redistributing
167
+ the Work or Derivative Works thereof, You may choose to offer,
168
+ and charge a fee for, acceptance of support, warranty, indemnity,
169
+ or other liability obligations and/or rights consistent with this
170
+ License. However, in accepting such obligations, You may act only
171
+ on Your own behalf and on Your sole responsibility, not on behalf
172
+ of any other Contributor, and only if You agree to indemnify,
173
+ defend, and hold each Contributor harmless for any liability
174
+ incurred by, or claims asserted against, such Contributor by reason
175
+ of your accepting any such warranty or additional liability.
176
+
177
+ END OF TERMS AND CONDITIONS
178
+
179
+ APPENDIX: How to apply the Apache License to your work.
180
+
181
+ To apply the Apache License to your work, attach the following
182
+ boilerplate notice, with the fields enclosed by brackets "[]"
183
+ replaced with your own identifying information. (Don't include
184
+ the brackets!) The text should be enclosed in the appropriate
185
+ comment syntax for the file format. We also recommend that a
186
+ file or class name and description of purpose be included on the
187
+ same "printed page" as the copyright notice for easier
188
+ identification within third-party archives.
189
+
190
+ Copyright (c) 2023 DeepSeek
191
+
192
+ Licensed under the Apache License, Version 2.0 (the "License");
193
+ you may not use this file except in compliance with the License.
194
+ You may obtain a copy of the License at
195
+
196
+ http://www.apache.org/licenses/LICENSE-2.0
197
+
198
+ Unless required by applicable law or agreed to in writing, software
199
+ distributed under the License is distributed on an "AS IS" BASIS,
200
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
201
+ See the License for the specific language governing permissions and
202
+ limitations under the License.
README.md CHANGED
@@ -7,7 +7,7 @@ tags:
7
  - vision-language
8
  - ocr
9
  - custom_code
10
- license: apache-2.0
11
  library_name: transformers
12
  ---
13
  <div align="center">
 
7
  - vision-language
8
  - ocr
9
  - custom_code
10
+ license: mit
11
  library_name: transformers
12
  ---
13
  <div align="center">