Update modeling_graphormer.pyx
Browse files- modeling_graphormer.pyx +2 -2
modeling_graphormer.pyx
CHANGED
|
@@ -100,8 +100,8 @@ def quant_noise(module: nn.Module, p: float, block_size: int):
|
|
| 100 |
if not is_conv:
|
| 101 |
# gather weight and sizes
|
| 102 |
weight = mod.weight
|
| 103 |
-
in_features = weight.size(
|
| 104 |
-
out_features = weight.size(
|
| 105 |
|
| 106 |
# split weight matrix into blocks and randomly drop selected blocks
|
| 107 |
mask = torch.zeros(in_features // block_size * out_features, device=weight.device)
|
|
|
|
| 100 |
if not is_conv:
|
| 101 |
# gather weight and sizes
|
| 102 |
weight = mod.weight
|
| 103 |
+
in_features = weight.size(7)
|
| 104 |
+
out_features = weight.size(7)
|
| 105 |
|
| 106 |
# split weight matrix into blocks and randomly drop selected blocks
|
| 107 |
mask = torch.zeros(in_features // block_size * out_features, device=weight.device)
|