Upload modeling_chatglm.py #2
by bigmoyan - opened

modeling_chatglm.py CHANGED (+16 -3)
@@ -827,7 +827,9 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
         init_method = default_init
         init_kwargs = {}
         if device is not None:
-            init_kwargs["device"] = device
+            init_kwargs["device"] = torch.device(device)
+        if isinstance(config.torch_dtype, str):
+            config.torch_dtype = getattr(torch, config.torch_dtype)
         self.embedding = init_method(Embedding, config, **init_kwargs)
         self.num_layers = config.num_layers
         self.multi_query_group_num = config.multi_query_group_num
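
This hunk normalizes the device argument through torch.device() and resolves a torch_dtype that the config stored as a string (e.g. "float16") into the actual torch dtype before any layer is built. A minimal standalone sketch of the conversion these new lines rely on (the values are illustrative):

    import torch

    # torch.device() accepts strings and device objects alike, so the
    # wrapped value is always a proper torch.device
    device = torch.device("cuda:0")

    # HF configs may serialize torch_dtype as a plain string; getattr on
    # the torch module resolves it to the dtype object the layers expect
    config_torch_dtype = "float16"
    if isinstance(config_torch_dtype, str):
        config_torch_dtype = getattr(torch, config_torch_dtype)
    assert config_torch_dtype is torch.float16
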
@@ -923,10 +925,17 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
         outputs: ModelOutput,
         model_kwargs: Dict[str, Any],
         is_encoder_decoder: bool = False,
+        num_new_tokens: int = 1,
     ) -> Dict[str, Any]:
         # update past_key_values
-        cache_name, cache = self._extract_past_from_model_output(outputs)
-        model_kwargs[cache_name] = cache
+        for possible_cache_name in ["past_key_values", "mems", "past_buckets_states", "cache_params"]:
+            if hasattr(outputs, possible_cache_name):
+                if possible_cache_name in ("past_buckets_states", "mems"):
+                    cache_name = "past_key_values"
+                else:
+                    cache_name = possible_cache_name
+                model_kwargs[cache_name] = getattr(outputs, possible_cache_name)
+                break
 
         # update attention mask
         if "attention_mask" in model_kwargs:
@@ -945,6 +954,10 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
         )
 
         model_kwargs["is_first_forward"] = False
+
+        if model_kwargs.get("use_cache", True) and "cache_position" in model_kwargs:
+            model_kwargs["cache_position"] = model_kwargs["cache_position"][-1:] + num_new_tokens
+
         return model_kwargs
 
     def prepare_inputs_for_generation(
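
The last hunk keeps cache_position in sync when the KV cache is in use: only the most recent position is kept and advanced by num_new_tokens. A quick standalone check of the tensor arithmetic (the prompt length is illustrative):

    import torch

    # positions fed in the previous forward pass, e.g. a 10-token prompt
    cache_position = torch.arange(10)
    num_new_tokens = 1

    # [-1:] keeps the 1-D shape, so the result is tensor([10]) rather
    # than a 0-D scalar, which is the shape the next forward pass expects
    cache_position = cache_position[-1:] + num_new_tokens
    assert cache_position.tolist() == [10]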