Munaza10 committed
Commit f277f76 · verified · 1 Parent(s): ae0b038

Delete siglip2.py

Files changed (1):
  1. siglip2.py +0 -564
siglip2.py DELETED
@@ -1,564 +0,0 @@
- # Licensed under the TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # https://github.com/Tencent-Hunyuan/HunyuanImage-3.0/blob/main/LICENSE
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- # ==============================================================================
- #
- # Copyright 2025 The HuggingFace Inc. team.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- # ==============================================================================
-
- from typing import Optional, Tuple, Union
- import warnings
-
- import torch
- import torch.nn as nn
- import torch.nn.functional as F
-
- from transformers.activations import ACT2FN
- from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling
- from transformers.modeling_attn_mask_utils import _prepare_4d_attention_mask
-
-
- class Config(object):
-     def __init__(self, config):
-         if config is not None:
-             for key, value in config.items():
-                 setattr(self, key, value)
-
-     def __getitem__(self, key):
-         return getattr(self, key, None)
-
-     def __setitem__(self, key, value):
-         return setattr(self, key, value)
-
-
- class Siglip2VisionEmbeddings(nn.Module):
-     def __init__(self, config):
-         super().__init__()
-         self.config = config
-         self.embed_dim = config.hidden_size
-         self.patch_size = config.patch_size
-
-         self.patch_embedding = nn.Linear(
-             in_features=config.num_channels * self.patch_size * self.patch_size,
-             out_features=self.embed_dim,
-         )
-
-         self.num_patches = config.num_patches
-         self.position_embedding_size = int(self.num_patches**0.5)
-         self.position_embedding = nn.Embedding(self.num_patches, self.embed_dim)
-
-     @staticmethod
-     def resize_positional_embeddings(
-         positional_embeddings: torch.Tensor,
-         spatial_shapes: torch.LongTensor,
-         max_length: int,
-     ) -> torch.Tensor:
-         """
-         Resize positional embeddings to image-specific size and pad to a fixed size.
-
-         Args:
-             positional_embeddings (`torch.Tensor`):
-                 Position embeddings of shape (height, width, embed_dim)
-             spatial_shapes (`torch.LongTensor`):
-                 Spatial shapes of shape (batch_size, 2) to resize the positional embeddings to
-             max_length (`int`):
-                 Maximum length of the positional embeddings to pad resized positional embeddings to
-
-         Returns:
-             `torch.Tensor`: Embeddings of shape (batch_size, max_length, embed_dim)
-         """
-         batch_size = spatial_shapes.shape[0]
-         embed_dim = positional_embeddings.shape[-1]
-         source_dtype = positional_embeddings.dtype
-
-         resulted_positional_embeddings = torch.empty(
-             (batch_size, max_length, embed_dim),
-             device=positional_embeddings.device,
-             dtype=source_dtype,
-         )
-
-         # (height, width, embed_dim) -> (1, embed_dim, height, width) for interpolation
-         positional_embeddings = positional_embeddings.permute(2, 0, 1).unsqueeze(0)
-
-         # Upcast to float32 on CPU because antialias is not supported for bfloat16/float16 on CPU
-         if positional_embeddings.device.type == "cpu":
-             positional_embeddings = positional_embeddings.to(torch.float32)
-
-         for i in range(batch_size):
-             # (1, dim, height, width) -> (1, dim, target_height, target_width)
-             height, width = spatial_shapes[i]
-             resized_embeddings = F.interpolate(
-                 positional_embeddings,
-                 size=(height, width),
-                 mode="bilinear",
-                 align_corners=False,
-                 antialias=True,
-             )
-
-             # (1, dim, target_height, target_width) -> (target_height * target_width, dim)
-             resized_embeddings = resized_embeddings.reshape(embed_dim, height * width).transpose(0, 1)
-
-             # Cast to original dtype
-             resized_embeddings = resized_embeddings.to(source_dtype)
-
-             resulted_positional_embeddings[i, : height * width] = resized_embeddings
-             resulted_positional_embeddings[i, height * width :] = resized_embeddings[0]
-
-         return resulted_positional_embeddings
-
-     def forward(self, pixel_values: torch.FloatTensor, spatial_shapes: torch.LongTensor) -> torch.Tensor:
-         """
-         Args:
-             pixel_values (`torch.FloatTensor`):
-                 Pixel values of shape (batch_size, max_num_patches, num_channels * patch_size * patch_size)
-             spatial_shapes (`List[Tuple[int, int]]`):
-                 Spatial shapes of shape (batch_size, 2) to resize the positional embeddings to
-         """
-
-         # Apply patch embeddings to already patchified pixel values
-         target_dtype = self.patch_embedding.weight.dtype
-         patch_embeds = self.patch_embedding(pixel_values.to(dtype=target_dtype))
-
-         # Get positional resized and padded positional embeddings
-         positional_embeddings = self.position_embedding.weight.reshape(
-             self.position_embedding_size, self.position_embedding_size, -1
-         )
-         resized_positional_embeddings = self.resize_positional_embeddings(
-             positional_embeddings, spatial_shapes, max_length=pixel_values.shape[1]
-         )
-
-         # Add positional embeddings to patch embeddings
-         embeddings = patch_embeds + resized_positional_embeddings
-         return embeddings
-
-
- class Siglip2Attention(nn.Module):
-     """Multi-headed attention from 'Attention Is All You Need' paper"""
-
-     def __init__(self, config):
-         super().__init__()
-         self.config = config
-         self.embed_dim = config.hidden_size
-         self.num_heads = config.num_attention_heads
-         self.head_dim = self.embed_dim // self.num_heads
-         if self.head_dim * self.num_heads != self.embed_dim:
-             raise ValueError(
-                 f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
-                 f" {self.num_heads})."
-             )
-         self.scale = self.head_dim**-0.5
-         self.dropout = config.attention_dropout
-
-         self.k_proj = nn.Linear(self.embed_dim, self.embed_dim)
-         self.v_proj = nn.Linear(self.embed_dim, self.embed_dim)
-         self.q_proj = nn.Linear(self.embed_dim, self.embed_dim)
-         self.out_proj = nn.Linear(self.embed_dim, self.embed_dim)
-
-     def forward(
-         self,
-         hidden_states: torch.Tensor,
-         attention_mask: Optional[torch.Tensor] = None,
-         output_attentions: Optional[bool] = False,
-     ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
-         """Input shape: Batch x Time x Channel"""
-
-         batch_size, q_len, _ = hidden_states.size()
-
-         query_states = self.q_proj(hidden_states)
-         key_states = self.k_proj(hidden_states)
-         value_states = self.v_proj(hidden_states)
-
-         query_states = query_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2)
-         key_states = key_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2)
-         value_states = value_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2)
-
-         k_v_seq_len = key_states.shape[-2]
-         attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) * self.scale
-
-         if attn_weights.size() != (batch_size, self.num_heads, q_len, k_v_seq_len):
-             raise ValueError(
-                 f"Attention weights should be of size {(batch_size, self.num_heads, q_len, k_v_seq_len)}, but is"
-                 f" {attn_weights.size()}"
-             )
-
-         if attention_mask is not None:
-             if attention_mask.size() != (batch_size, 1, q_len, k_v_seq_len):
-                 raise ValueError(
-                     f"Attention mask should be of size {(batch_size, 1, q_len, k_v_seq_len)}, "
-                     f"but is {attention_mask.size()}"
-                 )
-             attn_weights = attn_weights + attention_mask
-
-         # upcast attention to fp32
-         attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
-         attn_weights = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
-         attn_output = torch.matmul(attn_weights, value_states)
-
-         if attn_output.size() != (batch_size, self.num_heads, q_len, self.head_dim):
-             raise ValueError(
-                 f"`attn_output` should be of size {(batch_size, self.num_heads, q_len, self.head_dim)}, but is"
-                 f" {attn_output.size()}"
-             )
-
-         attn_output = attn_output.transpose(1, 2).contiguous()
-         attn_output = attn_output.reshape(batch_size, q_len, self.embed_dim)
-
-         attn_output = self.out_proj(attn_output)
-
-         return attn_output, attn_weights
-
- class Siglip2SdpaAttention(Siglip2Attention):
-     """
-     Siglip2 attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
-     `Siglip2Attention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt
-     to SDPA API.
-     """
-
-     is_causal = False
-
-     # Adapted from Siglip2Attention.forward and transformers.models.llama.modeling_llama.LlamaSdpaAttention.forward
-     def forward(
-         self,
-         hidden_states: torch.Tensor,
-         attention_mask: Optional[torch.Tensor] = None,
-         output_attentions: Optional[bool] = False,
-     ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
-         if output_attentions:
-             # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"`
-             # once this is implemented.
-             warnings.warn(
-                 "Siglip2Model is using Siglip2SdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` "
-                 "does not support `output_attentions=True`. Falling back to the manual attention implementation, "
-                 "but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. "
-                 'This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
-             )
-             return super().forward(
-                 hidden_states=hidden_states,
-                 attention_mask=attention_mask,
-                 output_attentions=output_attentions,
-             )
-
-         batch_size, q_len, _ = hidden_states.size()
-
-         query_states = self.q_proj(hidden_states)
-         key_states = self.k_proj(hidden_states)
-         value_states = self.v_proj(hidden_states)
-
-         query_states = query_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2)
-         key_states = key_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2)
-         value_states = value_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2)
-
-         # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with
-         # custom attn_mask.
-         # Reference: https://github.com/pytorch/pytorch/issues/112577.
-         if query_states.device.type == "cuda" and attention_mask is not None:
-             query_states = query_states.contiguous()
-             key_states = key_states.contiguous()
-             value_states = value_states.contiguous()
-
-         # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an
-         # inline conditional assignment in SDPA to support both torch.compile's dynamic shapes and full graph options.
-         # An inline conditional prevents dynamic shapes from compiling.
-         is_causal = True if self.is_causal and q_len > 1 else False
-
-         attn_output = torch.nn.functional.scaled_dot_product_attention(
-             query_states,
-             key_states,
-             value_states,
-             attn_mask=attention_mask,
-             dropout_p=self.dropout if self.training else 0.0,
-             is_causal=is_causal,
-         )
-
-         attn_output = attn_output.transpose(1, 2).contiguous()
-         attn_output = attn_output.view(batch_size, q_len, self.embed_dim)
-
-         attn_output = self.out_proj(attn_output)
-
-         return attn_output, None
-
-
- class Siglip2MLP(nn.Module):
-     def __init__(self, config):
-         super().__init__()
-         self.config = config
-         self.activation_fn = ACT2FN[config.hidden_act]
-         self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
-         self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)
-
-     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
-         hidden_states = self.fc1(hidden_states)
-         hidden_states = self.activation_fn(hidden_states)
-         hidden_states = self.fc2(hidden_states)
-         return hidden_states
-
-
- class Siglip2EncoderLayer(nn.Module):
-     def __init__(self, config):
-         super().__init__()
-         self.embed_dim = config.hidden_size
-         self.self_attn = Siglip2Attention(config=config)
-         self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
-         self.mlp = Siglip2MLP(config)
-         self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
-
-     # Ignore copy
-     def forward(
-         self,
-         hidden_states: torch.Tensor,
-         attention_mask: torch.Tensor,
-         output_attentions: Optional[bool] = False,
-     ) -> Tuple[torch.FloatTensor]:
-         """
-         Args:
-             hidden_states (`torch.FloatTensor`):
-                 Input to the layer of shape `(batch, seq_len, embed_dim)`.
-             attention_mask (`torch.FloatTensor`):
-                 Attention mask of shape `(batch, 1, q_len, k_v_seq_len)` where padding elements are indicated by very
-                 large negative values.
-             output_attentions (`bool`, *optional*, defaults to `False`):
-                 Whether or not to return the attentions tensors of all attention layers. See `attentions` under
-                 returned tensors for more detail.
-         """
-         residual = hidden_states
-
-         hidden_states = self.layer_norm1(hidden_states)
-         hidden_states, attn_weights = self.self_attn(
-             hidden_states=hidden_states,
-             attention_mask=attention_mask,
-             output_attentions=output_attentions,
-         )
-         hidden_states = residual + hidden_states
-
-         residual = hidden_states
-         hidden_states = self.layer_norm2(hidden_states)
-         hidden_states = self.mlp(hidden_states)
-         hidden_states = residual + hidden_states
-
-         outputs = (hidden_states,)
-
-         if output_attentions:
-             outputs += (attn_weights,)
-
-         return outputs
-
-
- class Siglip2Encoder(nn.Module):
-     """
-     Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a
-     [`Siglip2EncoderLayer`].
-
-     Args:
-         config: Siglip2Config
-     """
-
-     def __init__(self, config):
-         super().__init__()
-         self.config = config
-         self.layers = nn.ModuleList([Siglip2EncoderLayer(config) for _ in range(config.num_hidden_layers)])
-         self.gradient_checkpointing = True
-
-     # Ignore copy
-     def forward(
-         self,
-         inputs_embeds,
-         attention_mask: Optional[torch.Tensor] = None,
-         output_attentions: Optional[bool] = None,
-         output_hidden_states: Optional[bool] = None,
-         return_dict: Optional[bool] = None,
-     ) -> Union[Tuple, BaseModelOutput]:
-         r"""
-         Args:
-             inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
-                 Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
-                 This is useful if you want more control over how to convert `input_ids` indices into associated vectors
-                 than the model's internal embedding lookup matrix.
-             attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
-                 Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
-
-                 - 1 for tokens that are **not masked**,
-                 - 0 for tokens that are **masked**.
-
-                 [What are attention masks?](../glossary#attention-mask)
-             output_attentions (`bool`, *optional*):
-                 Whether or not to return the attentions tensors of all attention layers. See `attentions` under
-                 returned tensors for more detail.
-             output_hidden_states (`bool`, *optional*):
-                 Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
-                 for more detail.
-             return_dict (`bool`, *optional*):
-                 Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
-         """
-         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
-         output_hidden_states = (
-             output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
-         )
-         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
-
-         encoder_states = () if output_hidden_states else None
-         all_attentions = () if output_attentions else None
-
-         hidden_states = inputs_embeds
-         for layer_index, encoder_layer in enumerate(self.layers):  # len(self.layers): 27
-             if output_hidden_states:
-                 encoder_states = encoder_states + (hidden_states,)
-
-             layer_outputs = encoder_layer(
-                 hidden_states,
-                 attention_mask,
-                 output_attentions=output_attentions,
-             )
-
-             hidden_states = layer_outputs[0]
-
-             if output_attentions:
-                 all_attentions = all_attentions + (layer_outputs[1],)
-
-         if output_hidden_states:
-             encoder_states = encoder_states + (hidden_states,)
-
-         if not return_dict:
-             return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None)
-         return BaseModelOutput(
-             last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions
-         )
-
-
- class Siglip2MultiheadAttentionPoolingHead(nn.Module):
-     """Multihead Attention Pooling."""
-
-     def __init__(self, config):
-         super().__init__()
-
-         self.probe = nn.Parameter(torch.randn(1, 1, config.hidden_size))
-         self.attention = torch.nn.MultiheadAttention(config.hidden_size, config.num_attention_heads, batch_first=True)
-         self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
-         self.mlp = Siglip2MLP(config)
-         self.num_heads = config.num_attention_heads
-
-     def forward(self, hidden_state: torch.Tensor, attention_mask: Optional[torch.Tensor] = None):
-         batch_size = hidden_state.shape[0]
-         probe = self.probe.repeat(batch_size, 1, 1)
-
-         if attention_mask is not None:
-             target_len, source_len = probe.shape[1], hidden_state.shape[1]
-             attention_mask = _prepare_4d_attention_mask(attention_mask, hidden_state.dtype, target_len)
-             attention_mask = attention_mask.repeat(1, self.num_heads, target_len, 1)
-             attention_mask = attention_mask.reshape(-1, target_len, source_len)
-
-         hidden_state = self.attention(probe, hidden_state, hidden_state, attn_mask=attention_mask)[0]
-
-         residual = hidden_state
-         hidden_state = self.layernorm(hidden_state)
-         hidden_state = residual + self.mlp(hidden_state)
-
-         return hidden_state[:, 0]
-
-
- class Siglip2VisionTransformer(nn.Module):
-     def __init__(self, config):
-         super().__init__()
-         config = Config(config)
-         self.config = config
-         embed_dim = config.hidden_size
-
-         self.embeddings = Siglip2VisionEmbeddings(config)
-         self.encoder = Siglip2Encoder(config)
-         self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
-         self.use_head = True if not hasattr(config, "vision_use_head") else config.vision_use_head
-         if self.use_head:
-             self.head = Siglip2MultiheadAttentionPoolingHead(config)
-         self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
-
-     def forward(
-         self,
-         pixel_values: torch.FloatTensor,
-         attention_mask: torch.Tensor,
-         spatial_shapes: torch.LongTensor,
-         output_attentions: Optional[bool] = None,
-         output_hidden_states: Optional[bool] = None,
-         return_dict: Optional[bool] = None,
-     ) -> Union[Tuple, BaseModelOutputWithPooling]:
-         r"""
-         Returns:
-
-         """
-         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
-         output_hidden_states = (
-             output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
-         )
-         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
-
-         hidden_states = self.embeddings(pixel_values, spatial_shapes)
-
-         if attention_mask is not None and not self._use_flash_attention_2:
-             # [batch_size, seq_len] -> [batch_size, 1, tgt_seq_len, src_seq_len]
-             encoder_attention_mask = _prepare_4d_attention_mask(attention_mask, hidden_states.dtype)
-         else:
-             encoder_attention_mask = attention_mask
-
-         encoder_outputs = self.encoder(
-             inputs_embeds=hidden_states,
-             attention_mask=encoder_attention_mask,
-             output_attentions=output_attentions,
-             output_hidden_states=output_hidden_states,
-             return_dict=return_dict,
-         )
-
-         last_hidden_state = encoder_outputs[0]
-         last_hidden_state = self.post_layernorm(last_hidden_state)
-
-         pooler_output = self.head(last_hidden_state, attention_mask) if self.use_head else None
-         if not return_dict:
-             return (last_hidden_state, pooler_output) + encoder_outputs[1:]
-
-         return BaseModelOutputWithPooling(
-             last_hidden_state=last_hidden_state,
-             pooler_output=pooler_output,
-             hidden_states=encoder_outputs.hidden_states,
-             attentions=encoder_outputs.attentions,
-         )
-
-
- class LightProjector(nn.Module):
-     def __init__(self, config):
-         config = Config(config)
-         super().__init__()
-
-         if config.projector_type == "linear":
-             modules = nn.Linear(config.input_dim, config.n_embed)
-
-         elif config.projector_type == "mlp_gelu":
-             modules = [nn.Linear(config.input_dim, config.n_embed)]
-             for _ in range(1, config.depth):
-                 modules.append(nn.GELU())
-                 modules.append(nn.Linear(config.n_embed, config.n_embed))
-             modules = nn.Sequential(*modules)
-
-         else:
-             raise ValueError(f"Unknown projector type: {config.projector_type}")
-
-         self.layers = modules
-
-     def forward(self, x):
-         return self.layers(x)
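
For reference, below is a minimal usage sketch of the module removed by this commit, assuming `Siglip2VisionTransformer` and `LightProjector` are imported from the previous revision's siglip2.py and that a dict is passed where the classes wrap it in the internal `Config` helper. All config values, shapes, and the import path are illustrative assumptions, not the checkpoint's real settings; the point is only the expected input format: pre-patchified pixel values, a patch-level attention mask, and per-image spatial shapes.

# Illustrative sketch only: tiny dimensions, not the real HunyuanImage-3.0 / SigLIP2 sizes.
import torch
from siglip2 import Siglip2VisionTransformer, LightProjector  # prior revision of this file (assumed path)

vision_config = {
    "hidden_size": 64,
    "intermediate_size": 128,
    "num_hidden_layers": 2,
    "num_attention_heads": 4,
    "num_channels": 3,
    "patch_size": 16,
    "num_patches": 256,               # 16 x 16 grid of learned position embeddings
    "hidden_act": "gelu_pytorch_tanh",
    "layer_norm_eps": 1e-6,
    "attention_dropout": 0.0,
    "_attn_implementation": "eager",  # anything other than "flash_attention_2" uses the 4D mask path
}
vision_tower = Siglip2VisionTransformer(vision_config).eval()

# Two images, already split into flattened 16x16 patches and padded to 64 patches each.
batch_size, max_patches = 2, 64
pixel_values = torch.randn(batch_size, max_patches, 3 * 16 * 16)
spatial_shapes = torch.tensor([[8, 8], [6, 8]])           # (height, width) in patches per image
attention_mask = torch.zeros(batch_size, max_patches, dtype=torch.long)
for i, (h, w) in enumerate(spatial_shapes.tolist()):
    attention_mask[i, : h * w] = 1                        # mark real patches, leave padding at 0

with torch.no_grad():
    outputs = vision_tower(
        pixel_values=pixel_values,
        attention_mask=attention_mask,
        spatial_shapes=spatial_shapes,
        output_attentions=False,
        output_hidden_states=False,
        return_dict=True,
    )
print(outputs.last_hidden_state.shape)                    # torch.Size([2, 64, 64])
print(outputs.pooler_output.shape)                        # torch.Size([2, 64])

# Project the patch features to another width with the deleted LightProjector.
projector = LightProjector({"projector_type": "mlp_gelu", "input_dim": 64, "n_embed": 128, "depth": 2})
image_embeds = projector(outputs.last_hidden_state)       # torch.Size([2, 64, 128])

The `output_attentions`, `output_hidden_states`, and `return_dict` arguments are passed explicitly here because the lightweight `Config` wrapper only exposes keys that are actually present in the dict.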