make compatible with transformers 4.49+ (#43)
- make compatible with transformers 4.49+ (1a5be979ef53262299363e521777777d1f356869)
Co-authored-by: Blair Chintella <[email protected]>
- modeling_chatglm.py +13 -7
modeling_chatglm.py CHANGED

@@ -1082,19 +1082,22 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
             outputs: ModelOutput,
             model_kwargs: Dict[str, Any],
             is_encoder_decoder: bool = False,
+            num_new_tokens: int = 1,
     ) -> Dict[str, Any]:
-
-
-
+        for possible_cache_name in ["past_key_values", "mems", "past_buckets_states", "cache_params"]:
+            if hasattr(outputs, possible_cache_name):
+                if possible_cache_name in ("past_buckets_states", "mems"):
+                    cache_name = "past_key_values"
+                else:
+                    cache_name = possible_cache_name
+                model_kwargs[cache_name] = getattr(outputs, possible_cache_name)
+                break
 
-        # update attention mask
         if "attention_mask" in model_kwargs:
             attention_mask = model_kwargs["attention_mask"]
             model_kwargs["attention_mask"] = torch.cat(
                 [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1
             )
-
-        # update position ids
         if "position_ids" in model_kwargs:
             position_ids = model_kwargs["position_ids"]
             new_position_id = position_ids[..., -1:].clone()
@@ -1102,8 +1105,11 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
             model_kwargs["position_ids"] = torch.cat(
                 [position_ids, new_position_id], dim=-1
             )
-
         model_kwargs["is_first_forward"] = False
+
+        if model_kwargs.get("use_cache", True) and "cache_position" in model_kwargs:
+            model_kwargs["cache_position"] = model_kwargs["cache_position"][-1:] + num_new_tokens
+
         return model_kwargs
 
     def prepare_inputs_for_generation(
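For context, the loop introduced above copies whichever cache attribute the model output happens to carry into model_kwargs without going through transformers internals, which is presumably what stopped working under 4.49+. The snippet below is a minimal, self-contained sketch of just that lookup; the attribute names and the normalization to "past_key_values" come from the diff, while the _copy_cache_from_outputs helper and the SimpleNamespace stand-in for the real ModelOutput are purely illustrative.

from types import SimpleNamespace
from typing import Any, Dict

def _copy_cache_from_outputs(outputs: Any, model_kwargs: Dict[str, Any]) -> Dict[str, Any]:
    # Same lookup order as the loop added in this commit: the first cache-like
    # attribute found on the model output is copied into model_kwargs.
    for possible_cache_name in ["past_key_values", "mems", "past_buckets_states", "cache_params"]:
        if hasattr(outputs, possible_cache_name):
            if possible_cache_name in ("past_buckets_states", "mems"):
                # Legacy cache attribute names are normalized to "past_key_values".
                cache_name = "past_key_values"
            else:
                cache_name = possible_cache_name
            model_kwargs[cache_name] = getattr(outputs, possible_cache_name)
            break
    return model_kwargs

# Illustration only: a hypothetical stand-in for the ModelOutput returned by a forward pass.
fake_outputs = SimpleNamespace(past_key_values=("layer0_kv", "layer1_kv"))
print(_copy_cache_from_outputs(fake_outputs, {}))
# {'past_key_values': ('layer0_kv', 'layer1_kv')}

The other addition, the cache_position update guarded by use_cache, advances the cached position index by num_new_tokens on every generation step; this appears to mirror the bookkeeping recent transformers releases perform in GenerationMixin._update_model_kwargs_for_generation, so decoding loops that pass cache_position keep working.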