Add print statements
Browse files
- modeling_cogvlm.py +27 -25
modeling_cogvlm.py
CHANGED
|
@@ -241,33 +241,35 @@ class VisionExpertAttention(nn.Module):
|
|
| 241 |
key_states = self._transpose_for_scores(key_states) # B, H, L, HD
|
| 242 |
value_states = self._transpose_for_scores(value_states) # B, H, L, HD
|
| 243 |
|
| 244 |
-
|
| 245 |
-
torch.save(key_states, "key_states.pt")
|
| 246 |
-
torch.save(value_states, "value_states.pt")
|
| 247 |
|
| 248 |
-
|
|
|
|
|
|
|
| 249 |
|
| 250 |
-
|
| 251 |
-
|
| 252 |
-
|
| 253 |
-
|
| 254 |
-
|
| 255 |
-
|
| 256 |
-
|
| 257 |
-
|
| 258 |
-
|
| 259 |
-
|
| 260 |
-
|
| 261 |
-
|
| 262 |
-
|
| 263 |
-
|
| 264 |
-
|
| 265 |
-
|
| 266 |
-
|
| 267 |
-
|
| 268 |
-
|
| 269 |
-
|
| 270 |
-
|
|
|
|
|
|
|
| 271 |
|
| 272 |
kv_seq_len = key_states.shape[-2]
|
| 273 |
if past_key_value is not None:
|
|
|
|
| 241 |
key_states = self._transpose_for_scores(key_states) # B, H, L, HD
|
| 242 |
value_states = self._transpose_for_scores(value_states) # B, H, L, HD
|
| 243 |
|
| 244 |
+
if print_values:
|
|
|
|
|
|
|
| 245 |
|
| 246 |
+
torch.save(query_states, "query_states.pt")
|
| 247 |
+
torch.save(key_states, "key_states.pt")
|
| 248 |
+
torch.save(value_states, "value_states.pt")
|
| 249 |
|
| 250 |
+
from huggingface_hub import HfApi
|
| 251 |
+
|
| 252 |
+
api = HfApi()
|
| 253 |
+
api.upload_file(
|
| 254 |
+
path_or_fileobj="query_states.pt",
|
| 255 |
+
path_in_repo="query_states.pt",
|
| 256 |
+
repo_id="nielsr/test-cogvlm",
|
| 257 |
+
repo_type="dataset",
|
| 258 |
+
)
|
| 259 |
+
api = HfApi()
|
| 260 |
+
api.upload_file(
|
| 261 |
+
path_or_fileobj="key_states.pt",
|
| 262 |
+
path_in_repo="key_states.pt",
|
| 263 |
+
repo_id="nielsr/test-cogvlm",
|
| 264 |
+
repo_type="dataset",
|
| 265 |
+
)
|
| 266 |
+
api = HfApi()
|
| 267 |
+
api.upload_file(
|
| 268 |
+
path_or_fileobj="value_states.pt",
|
| 269 |
+
path_in_repo="value_states.pt",
|
| 270 |
+
repo_id="nielsr/test-cogvlm",
|
| 271 |
+
repo_type="dataset",
|
| 272 |
+
)
|
| 273 |
|
| 274 |
kv_seq_len = key_states.shape[-2]
|
| 275 |
if past_key_value is not None:
|