# generate.py (VERSÃO CORRIGIDA COM SISTEMA DE PENALIDADE) import torch import argparse from tokenizers import Tokenizer from model import LunarisCodex, LunarisCodexConfig from collections import defaultdict import math def unwrap_model_keys(state_dict): unwrapped = {} prefixes_to_remove = ['_orig_mod.module.', 'module.', '_orig_mod.'] for k, v in state_dict.items(): new_k = k for prefix in prefixes_to_remove: if new_k.startswith(prefix): new_k = new_k[len(prefix):] break unwrapped[new_k] = v return unwrapped def apply_repetition_penalty(logits, input_ids, penalty=1.0): """ Aplica penalidade de repetição aos tokens que já apareceram na sequência. """ if penalty == 1.0: return logits # Conta a frequência de cada token token_counts = defaultdict(int) for token_id in input_ids.flatten(): token_counts[token_id.item()] += 1 # Aplica penalidade baseada na frequência for token_id, count in token_counts.items(): if count > 0: # Se logit é positivo, divide pela penalidade # Se logit é negativo, multiplica pela penalidade if logits[0, token_id] > 0: logits[0, token_id] = logits[0, token_id] / penalty else: logits[0, token_id] = logits[0, token_id] * penalty return logits def apply_frequency_penalty(logits, input_ids, penalty=0.0): """ Aplica penalidade de frequência linear baseada no número de ocorrências. """ if penalty == 0.0: return logits # Conta a frequência de cada token token_counts = defaultdict(int) for token_id in input_ids.flatten(): token_counts[token_id.item()] += 1 # Aplica penalidade linear baseada na frequência for token_id, count in token_counts.items(): if count > 0: logits[0, token_id] = logits[0, token_id] - penalty * count return logits def apply_presence_penalty(logits, input_ids, penalty=0.0): """ Aplica penalidade de presença - penaliza tokens que já apareceram pelo menos uma vez. """ if penalty == 0.0: return logits # Identifica tokens únicos que já apareceram unique_tokens = set(input_ids.flatten().tolist()) # Aplica penalidade fixa para tokens que já apareceram for token_id in unique_tokens: logits[0, token_id] = logits[0, token_id] - penalty return logits def apply_typical_sampling(logits, typical_p=1.0): """ Aplica Typical Sampling - mantém tokens com probabilidade "típica". """ if typical_p >= 1.0: return logits # Calcula probabilidades probs = torch.softmax(logits, dim=-1) # Calcula a entropia entropy = -torch.sum(probs * torch.log(probs + 1e-10), dim=-1) # Calcula surpresa de cada token surprisal = -torch.log(probs + 1e-10) # Mantém apenas tokens com surpresa próxima da entropia deviation = torch.abs(surprisal - entropy.unsqueeze(-1)) # Ordena por desvio e mantém apenas os típicos sorted_indices = torch.argsort(deviation, dim=-1) sorted_probs = probs.gather(-1, sorted_indices) # Acumula probabilidades até atingir typical_p cumulative_probs = torch.cumsum(sorted_probs, dim=-1) # Encontra o cutoff cutoff = torch.searchsorted(cumulative_probs, typical_p) cutoff = torch.clamp(cutoff, min=1) # Mantém pelo menos 1 token # Cria máscara para tokens típicos mask = torch.zeros_like(logits, dtype=torch.bool) for i in range(logits.size(0)): typical_indices = sorted_indices[i, :cutoff[i]] mask[i].scatter_(0, typical_indices, True) # Aplica máscara logits = logits.masked_fill(~mask, -float('inf')) return logits def safe_softmax_sampling(logits, temperature=1.0): """ Aplica softmax e amostragem de forma segura, evitando valores inválidos. """ # Evita temperature zero ou muito próximo de zero if temperature <= 1e-5: temperature = 1e-5 # Aplica temperatura logits = logits / temperature # Remove valores infinitos negativos extremos para evitar underflow logits = torch.clamp(logits, min=-1e4, max=1e4) # Aplica softmax probs = torch.softmax(logits, dim=-1) # Verifica se há valores inválidos if torch.any(torch.isnan(probs)) or torch.any(torch.isinf(probs)) or torch.any(probs < 0): # Fallback: usa distribuição uniforme print("AVISO: Detectados valores inválidos nas probabilidades. Usando distribuição uniforme.") probs = torch.ones_like(probs) / probs.size(-1) # Garante que a soma é 1 probs = probs / probs.sum(dim=-1, keepdim=True) return probs def generate_with_penalties(model, idx, max_new_tokens, temperature=1.0, top_k=None, repetition_penalty=1.0, frequency_penalty=0.0, presence_penalty=0.0, typical_p=1.0, min_length=0): """ Geração de texto com sistema completo de penalidades e verificações de segurança. """ # Tokens especiais que podem indicar fim de sequência eos_tokens = set([0, 1, 2]) # Ajuste conforme necessário original_length = idx.size(1) for step in range(max_new_tokens): # Crop da sequência se necessário idx_cond = idx if idx.size(1) <= model.config.max_seq_len else idx[:, -model.config.max_seq_len:] # Forward do modelo logits, _ = model(idx_cond) logits = logits[:, -1, :] # Pega apenas o último token # Aplica penalidades logits = apply_repetition_penalty(logits, idx, repetition_penalty) logits = apply_frequency_penalty(logits, idx, frequency_penalty) logits = apply_presence_penalty(logits, idx, presence_penalty) # Evita tokens de fim se ainda não atingiu o comprimento mínimo current_length = idx.size(1) if current_length - original_length < min_length: for eos_token in eos_tokens: if eos_token < logits.size(-1): logits[0, eos_token] = -float('inf') # Aplica typical sampling logits = apply_typical_sampling(logits, typical_p) # Aplica top-k if top_k is not None: v, _ = torch.topk(logits, min(top_k, logits.size(-1))) logits[logits < v[:, [-1]]] = -float('inf') # Amostragem segura probs = safe_softmax_sampling(logits, temperature) # Amostra o próximo token idx_next = torch.multinomial(probs, num_samples=1) # Adiciona o token à sequência idx = torch.cat((idx, idx_next), dim=1) # Verifica se deve parar (opcional) if idx_next.item() in eos_tokens and current_length - original_length >= min_length: break return idx def main(args): print("--- Iniciando Geração de Texto com Sistema de Penalidade ---") torch.manual_seed(1337) device = torch.device(args.device) print(f"Usando dispositivo: {device}") try: print(f"Carregando checkpoint de: {args.checkpoint_path}") checkpoint = torch.load(args.checkpoint_path, map_location=device, weights_only=False) except FileNotFoundError: print(f"ERRO: Arquivo de checkpoint não encontrado em '{args.checkpoint_path}'") return except Exception as e: print(f"ERRO: Falha ao carregar o checkpoint: {e}") return # Carrega configuração e modelo config = checkpoint['config']['model'] model = LunarisCodex(config) unwrapped_state_dict = unwrap_model_keys(checkpoint['model']) model.load_state_dict(unwrapped_state_dict) model.to(device) model.eval() print(f"Carregando tokenizador de: {args.tokenizer_path}") tokenizer = Tokenizer.from_file(args.tokenizer_path) # Prepara entrada start_ids = tokenizer.encode(args.prompt).ids x = torch.tensor(start_ids, dtype=torch.long, device=device)[None, ...] print("\n" + "="*50) print(f"Prompt: '{args.prompt}'") print("Parâmetros de geração:") print(f" Temperature: {args.temperature}") print(f" Top-k: {args.top_k}") print(f" Repetition penalty: {args.repetition_penalty}") print(f" Frequency penalty: {args.frequency_penalty}") print(f" Presence penalty: {args.presence_penalty}") print(f" Typical-p: {args.typical_p}") print(f" Min length: {args.min_length}") print("Gerando texto...") print("="*50) with torch.no_grad(): y = generate_with_penalties( model, x, args.max_new_tokens, temperature=args.temperature, top_k=args.top_k, repetition_penalty=args.repetition_penalty, frequency_penalty=args.frequency_penalty, presence_penalty=args.presence_penalty, typical_p=args.typical_p, min_length=args.min_length ) generated_text = tokenizer.decode(y[0].tolist()) print(generated_text) print("\n--- Geração Concluída ---") if __name__ == '__main__': parser = argparse.ArgumentParser(description="Gerar texto com sistema de penalidades do LunarisCodex.") # Parâmetros básicos parser.add_argument('--checkpoint_path', type=str, required=True, help='Caminho para o arquivo .pt do checkpoint.') parser.add_argument('--tokenizer_path', type=str, default='./lunaris-ultrafineweb-tokenizer.json', help='Caminho para o arquivo do tokenizador.') parser.add_argument('--prompt', type=str, default='The first step to build a rocket is', help='O texto inicial para o modelo completar.') parser.add_argument('--max_new_tokens', type=int, default=50, help='Número de novos tokens a serem gerados.') parser.add_argument('--device', type=str, default='cpu', help='Dispositivo para rodar a geração (ex: "cpu" ou "cuda").') # Parâmetros de controle de geração parser.add_argument('--temperature', type=float, default=0.8, help='Controla a aleatoriedade. Valores mais altos = mais criativo.') parser.add_argument('--top_k', type=int, default=200, help='Considera apenas os k tokens mais prováveis para amostragem.') # Parâmetros de penalidade parser.add_argument('--repetition_penalty', type=float, default=1.1, help='Penalidade de repetição. 1.0 = sem penalidade, >1.0 = penaliza repetições.') parser.add_argument('--frequency_penalty', type=float, default=0.0, help='Penalidade de frequência. 0.0 = sem penalidade, >0.0 = penaliza tokens frequentes.') parser.add_argument('--presence_penalty', type=float, default=0.0, help='Penalidade de presença. 0.0 = sem penalidade, >0.0 = penaliza tokens já usados.') parser.add_argument('--typical_p', type=float, default=1.0, help='Typical sampling. 1.0 = desabilitado, <1.0 = mantém tokens típicos.') parser.add_argument('--min_length', type=int, default=0, help='Comprimento mínimo antes de permitir tokens de fim.') args = parser.parse_args() main(args)