from transformers import (
    PretrainedConfig,
    PreTrainedModel,
    AutoConfig,
    AutoModel,
    AutoModelForCausalLM,
)
from torch import nn

from hf.llama import CustomAttentionLLaMa

# Config: only the hyperparameters needed to rebuild CustomAttentionLLaMa.
# model_type is the key used by the Auto* registration at the bottom of the file.
class MyLLaMaConfig(PretrainedConfig):
    model_type = "LLaMa"

    def __init__(
        self,
        embed_dim: int = 1536,
        n_layers: int = 24,
        n_heads: int = 24,
        n_chckpnt_segments: int = 24,
        **kwargs,
    ):
        self.embed_dim = embed_dim
        self.n_layers = n_layers
        self.n_heads = n_heads
        self.n_chckpnt_segments = n_chckpnt_segments
        super().__init__(**kwargs)
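
# Usage sketch (the path is a placeholder): the config round-trips through
# save_pretrained / from_pretrained as a plain config.json whose
# "model_type": "LLaMa" entry is what the Auto* registration below keys on.
#
#   cfg = MyLLaMaConfig(embed_dim=768, n_layers=12)
#   cfg.save_pretrained("checkpoints/tiny-llama")
#   cfg = MyLLaMaConfig.from_pretrained("checkpoints/tiny-llama")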

# Thin PreTrainedModel wrapper around the custom backbone: it inherits
# save_pretrained / from_pretrained and returns the loss/logits dict that
# the Trainer API expects.
class MyLLaMa(PreTrainedModel):
    config_class = MyLLaMaConfig

    def __init__(self, config: MyLLaMaConfig):
        super().__init__(config)
        self.model = CustomAttentionLLaMa(
            config.embed_dim,
            config.n_layers,
            config.n_heads,
            dropout=0,
            n_chckpnt_segments=config.n_chckpnt_segments,
        )

    def forward(self, tensor, labels=None):
        logits = self.model(tensor)["logits"]
        if labels is not None:
            # Flatten to (batch * seq_len, vocab_size) predictions and
            # (batch * seq_len,) targets, as cross_entropy expects; labels
            # are assumed to be already aligned (shifted) with the logits.
            loss = nn.functional.cross_entropy(
                logits.reshape(-1, logits.size(-1)), labels.reshape(-1)
            )
            return {"loss": loss, "logits": logits}
        return {"logits": logits}

# Make the custom classes discoverable through the generic Auto* factories.
AutoConfig.register("LLaMa", MyLLaMaConfig)
AutoModel.register(MyLLaMaConfig, MyLLaMa)
AutoModelForCausalLM.register(MyLLaMaConfig, MyLLaMa)
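
# Minimal end-to-end sketch (the path is a placeholder): once the classes are
# registered, the generic Auto* entry points can rebuild MyLLaMa from a saved
# checkpoint by reading model_type from its config.json.
if __name__ == "__main__":
    config = MyLLaMaConfig()
    model = MyLLaMa(config)
    model.save_pretrained("checkpoints/my-llama")

    reloaded = AutoModelForCausalLM.from_pretrained("checkpoints/my-llama")
    assert isinstance(reloaded, MyLLaMa)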