Lunaris-0.6B / generate.py
meryyllebr543's picture
Rename generate(1).py to generate.py
b2ad614 verified
# generate.py (VERSÃO CORRIGIDA COM SISTEMA DE PENALIDADE)
import torch
import argparse
from tokenizers import Tokenizer
from model import LunarisCodex, LunarisCodexConfig
from collections import defaultdict
import math
def unwrap_model_keys(state_dict):
unwrapped = {}
prefixes_to_remove = ['_orig_mod.module.', 'module.', '_orig_mod.']
for k, v in state_dict.items():
new_k = k
for prefix in prefixes_to_remove:
if new_k.startswith(prefix):
new_k = new_k[len(prefix):]
break
unwrapped[new_k] = v
return unwrapped
def apply_repetition_penalty(logits, input_ids, penalty=1.0):
"""
Aplica penalidade de repetição aos tokens que já apareceram na sequência.
"""
if penalty == 1.0:
return logits
# Conta a frequência de cada token
token_counts = defaultdict(int)
for token_id in input_ids.flatten():
token_counts[token_id.item()] += 1
# Aplica penalidade baseada na frequência
for token_id, count in token_counts.items():
if count > 0:
# Se logit é positivo, divide pela penalidade
# Se logit é negativo, multiplica pela penalidade
if logits[0, token_id] > 0:
logits[0, token_id] = logits[0, token_id] / penalty
else:
logits[0, token_id] = logits[0, token_id] * penalty
return logits
def apply_frequency_penalty(logits, input_ids, penalty=0.0):
"""
Aplica penalidade de frequência linear baseada no número de ocorrências.
"""
if penalty == 0.0:
return logits
# Conta a frequência de cada token
token_counts = defaultdict(int)
for token_id in input_ids.flatten():
token_counts[token_id.item()] += 1
# Aplica penalidade linear baseada na frequência
for token_id, count in token_counts.items():
if count > 0:
logits[0, token_id] = logits[0, token_id] - penalty * count
return logits
def apply_presence_penalty(logits, input_ids, penalty=0.0):
"""
Aplica penalidade de presença - penaliza tokens que já apareceram pelo menos uma vez.
"""
if penalty == 0.0:
return logits
# Identifica tokens únicos que já apareceram
unique_tokens = set(input_ids.flatten().tolist())
# Aplica penalidade fixa para tokens que já apareceram
for token_id in unique_tokens:
logits[0, token_id] = logits[0, token_id] - penalty
return logits
def apply_typical_sampling(logits, typical_p=1.0):
"""
Aplica Typical Sampling - mantém tokens com probabilidade "típica".
"""
if typical_p >= 1.0:
return logits
# Calcula probabilidades
probs = torch.softmax(logits, dim=-1)
# Calcula a entropia
entropy = -torch.sum(probs * torch.log(probs + 1e-10), dim=-1)
# Calcula surpresa de cada token
surprisal = -torch.log(probs + 1e-10)
# Mantém apenas tokens com surpresa próxima da entropia
deviation = torch.abs(surprisal - entropy.unsqueeze(-1))
# Ordena por desvio e mantém apenas os típicos
sorted_indices = torch.argsort(deviation, dim=-1)
sorted_probs = probs.gather(-1, sorted_indices)
# Acumula probabilidades até atingir typical_p
cumulative_probs = torch.cumsum(sorted_probs, dim=-1)
# Encontra o cutoff
cutoff = torch.searchsorted(cumulative_probs, typical_p)
cutoff = torch.clamp(cutoff, min=1) # Mantém pelo menos 1 token
# Cria máscara para tokens típicos
mask = torch.zeros_like(logits, dtype=torch.bool)
for i in range(logits.size(0)):
typical_indices = sorted_indices[i, :cutoff[i]]
mask[i].scatter_(0, typical_indices, True)
# Aplica máscara
logits = logits.masked_fill(~mask, -float('inf'))
return logits
def safe_softmax_sampling(logits, temperature=1.0):
"""
Aplica softmax e amostragem de forma segura, evitando valores inválidos.
"""
# Evita temperature zero ou muito próximo de zero
if temperature <= 1e-5:
temperature = 1e-5
# Aplica temperatura
logits = logits / temperature
# Remove valores infinitos negativos extremos para evitar underflow
logits = torch.clamp(logits, min=-1e4, max=1e4)
# Aplica softmax
probs = torch.softmax(logits, dim=-1)
# Verifica se há valores inválidos
if torch.any(torch.isnan(probs)) or torch.any(torch.isinf(probs)) or torch.any(probs < 0):
# Fallback: usa distribuição uniforme
print("AVISO: Detectados valores inválidos nas probabilidades. Usando distribuição uniforme.")
probs = torch.ones_like(probs) / probs.size(-1)
# Garante que a soma é 1
probs = probs / probs.sum(dim=-1, keepdim=True)
return probs
def generate_with_penalties(model, idx, max_new_tokens, temperature=1.0, top_k=None,
repetition_penalty=1.0, frequency_penalty=0.0,
presence_penalty=0.0, typical_p=1.0, min_length=0):
"""
Geração de texto com sistema completo de penalidades e verificações de segurança.
"""
# Tokens especiais que podem indicar fim de sequência
eos_tokens = set([0, 1, 2]) # Ajuste conforme necessário
original_length = idx.size(1)
for step in range(max_new_tokens):
# Crop da sequência se necessário
idx_cond = idx if idx.size(1) <= model.config.max_seq_len else idx[:, -model.config.max_seq_len:]
# Forward do modelo
logits, _ = model(idx_cond)
logits = logits[:, -1, :] # Pega apenas o último token
# Aplica penalidades
logits = apply_repetition_penalty(logits, idx, repetition_penalty)
logits = apply_frequency_penalty(logits, idx, frequency_penalty)
logits = apply_presence_penalty(logits, idx, presence_penalty)
# Evita tokens de fim se ainda não atingiu o comprimento mínimo
current_length = idx.size(1)
if current_length - original_length < min_length:
for eos_token in eos_tokens:
if eos_token < logits.size(-1):
logits[0, eos_token] = -float('inf')
# Aplica typical sampling
logits = apply_typical_sampling(logits, typical_p)
# Aplica top-k
if top_k is not None:
v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
logits[logits < v[:, [-1]]] = -float('inf')
# Amostragem segura
probs = safe_softmax_sampling(logits, temperature)
# Amostra o próximo token
idx_next = torch.multinomial(probs, num_samples=1)
# Adiciona o token à sequência
idx = torch.cat((idx, idx_next), dim=1)
# Verifica se deve parar (opcional)
if idx_next.item() in eos_tokens and current_length - original_length >= min_length:
break
return idx
def main(args):
print("--- Iniciando Geração de Texto com Sistema de Penalidade ---")
torch.manual_seed(1337)
device = torch.device(args.device)
print(f"Usando dispositivo: {device}")
try:
print(f"Carregando checkpoint de: {args.checkpoint_path}")
checkpoint = torch.load(args.checkpoint_path, map_location=device, weights_only=False)
except FileNotFoundError:
print(f"ERRO: Arquivo de checkpoint não encontrado em '{args.checkpoint_path}'")
return
except Exception as e:
print(f"ERRO: Falha ao carregar o checkpoint: {e}")
return
# Carrega configuração e modelo
config = checkpoint['config']['model']
model = LunarisCodex(config)
unwrapped_state_dict = unwrap_model_keys(checkpoint['model'])
model.load_state_dict(unwrapped_state_dict)
model.to(device)
model.eval()
print(f"Carregando tokenizador de: {args.tokenizer_path}")
tokenizer = Tokenizer.from_file(args.tokenizer_path)
# Prepara entrada
start_ids = tokenizer.encode(args.prompt).ids
x = torch.tensor(start_ids, dtype=torch.long, device=device)[None, ...]
print("\n" + "="*50)
print(f"Prompt: '{args.prompt}'")
print("Parâmetros de geração:")
print(f" Temperature: {args.temperature}")
print(f" Top-k: {args.top_k}")
print(f" Repetition penalty: {args.repetition_penalty}")
print(f" Frequency penalty: {args.frequency_penalty}")
print(f" Presence penalty: {args.presence_penalty}")
print(f" Typical-p: {args.typical_p}")
print(f" Min length: {args.min_length}")
print("Gerando texto...")
print("="*50)
with torch.no_grad():
y = generate_with_penalties(
model, x, args.max_new_tokens,
temperature=args.temperature,
top_k=args.top_k,
repetition_penalty=args.repetition_penalty,
frequency_penalty=args.frequency_penalty,
presence_penalty=args.presence_penalty,
typical_p=args.typical_p,
min_length=args.min_length
)
generated_text = tokenizer.decode(y[0].tolist())
print(generated_text)
print("\n--- Geração Concluída ---")
if __name__ == '__main__':
parser = argparse.ArgumentParser(description="Gerar texto com sistema de penalidades do LunarisCodex.")
# Parâmetros básicos
parser.add_argument('--checkpoint_path', type=str, required=True,
help='Caminho para o arquivo .pt do checkpoint.')
parser.add_argument('--tokenizer_path', type=str, default='./lunaris-ultrafineweb-tokenizer.json',
help='Caminho para o arquivo do tokenizador.')
parser.add_argument('--prompt', type=str, default='The first step to build a rocket is',
help='O texto inicial para o modelo completar.')
parser.add_argument('--max_new_tokens', type=int, default=50,
help='Número de novos tokens a serem gerados.')
parser.add_argument('--device', type=str, default='cpu',
help='Dispositivo para rodar a geração (ex: "cpu" ou "cuda").')
# Parâmetros de controle de geração
parser.add_argument('--temperature', type=float, default=0.8,
help='Controla a aleatoriedade. Valores mais altos = mais criativo.')
parser.add_argument('--top_k', type=int, default=200,
help='Considera apenas os k tokens mais prováveis para amostragem.')
# Parâmetros de penalidade
parser.add_argument('--repetition_penalty', type=float, default=1.1,
help='Penalidade de repetição. 1.0 = sem penalidade, >1.0 = penaliza repetições.')
parser.add_argument('--frequency_penalty', type=float, default=0.0,
help='Penalidade de frequência. 0.0 = sem penalidade, >0.0 = penaliza tokens frequentes.')
parser.add_argument('--presence_penalty', type=float, default=0.0,
help='Penalidade de presença. 0.0 = sem penalidade, >0.0 = penaliza tokens já usados.')
parser.add_argument('--typical_p', type=float, default=1.0,
help='Typical sampling. 1.0 = desabilitado, <1.0 = mantém tokens típicos.')
parser.add_argument('--min_length', type=int, default=0,
help='Comprimento mínimo antes de permitir tokens de fim.')
args = parser.parse_args()
main(args)