|
|
|
|
|
|
|
|
import torch |
|
|
import argparse |
|
|
from tokenizers import Tokenizer |
|
|
from model import LunarisCodex, LunarisCodexConfig |
|
|
from collections import defaultdict |
|
|
import math |
|
|
|
|
|
def unwrap_model_keys(state_dict): |
|
|
unwrapped = {} |
|
|
prefixes_to_remove = ['_orig_mod.module.', 'module.', '_orig_mod.'] |
|
|
for k, v in state_dict.items(): |
|
|
new_k = k |
|
|
for prefix in prefixes_to_remove: |
|
|
if new_k.startswith(prefix): |
|
|
new_k = new_k[len(prefix):] |
|
|
break |
|
|
unwrapped[new_k] = v |
|
|
return unwrapped |
|
|
|
|
|
def apply_repetition_penalty(logits, input_ids, penalty=1.0): |
|
|
""" |
|
|
Aplica penalidade de repetição aos tokens que já apareceram na sequência. |
|
|
""" |
|
|
if penalty == 1.0: |
|
|
return logits |
|
|
|
|
|
|
|
|
token_counts = defaultdict(int) |
|
|
for token_id in input_ids.flatten(): |
|
|
token_counts[token_id.item()] += 1 |
|
|
|
|
|
|
|
|
for token_id, count in token_counts.items(): |
|
|
if count > 0: |
|
|
|
|
|
|
|
|
if logits[0, token_id] > 0: |
|
|
logits[0, token_id] = logits[0, token_id] / penalty |
|
|
else: |
|
|
logits[0, token_id] = logits[0, token_id] * penalty |
|
|
|
|
|
return logits |
|
|
|
|
|
def apply_frequency_penalty(logits, input_ids, penalty=0.0): |
|
|
""" |
|
|
Aplica penalidade de frequência linear baseada no número de ocorrências. |
|
|
""" |
|
|
if penalty == 0.0: |
|
|
return logits |
|
|
|
|
|
|
|
|
token_counts = defaultdict(int) |
|
|
for token_id in input_ids.flatten(): |
|
|
token_counts[token_id.item()] += 1 |
|
|
|
|
|
|
|
|
for token_id, count in token_counts.items(): |
|
|
if count > 0: |
|
|
logits[0, token_id] = logits[0, token_id] - penalty * count |
|
|
|
|
|
return logits |
|
|
|
|
|
def apply_presence_penalty(logits, input_ids, penalty=0.0): |
|
|
""" |
|
|
Aplica penalidade de presença - penaliza tokens que já apareceram pelo menos uma vez. |
|
|
""" |
|
|
if penalty == 0.0: |
|
|
return logits |
|
|
|
|
|
|
|
|
unique_tokens = set(input_ids.flatten().tolist()) |
|
|
|
|
|
|
|
|
for token_id in unique_tokens: |
|
|
logits[0, token_id] = logits[0, token_id] - penalty |
|
|
|
|
|
return logits |
|
|
|
|
|
def apply_typical_sampling(logits, typical_p=1.0): |
|
|
""" |
|
|
Aplica Typical Sampling - mantém tokens com probabilidade "típica". |
|
|
""" |
|
|
if typical_p >= 1.0: |
|
|
return logits |
|
|
|
|
|
|
|
|
probs = torch.softmax(logits, dim=-1) |
|
|
|
|
|
|
|
|
entropy = -torch.sum(probs * torch.log(probs + 1e-10), dim=-1) |
|
|
|
|
|
|
|
|
surprisal = -torch.log(probs + 1e-10) |
|
|
|
|
|
|
|
|
deviation = torch.abs(surprisal - entropy.unsqueeze(-1)) |
|
|
|
|
|
|
|
|
sorted_indices = torch.argsort(deviation, dim=-1) |
|
|
sorted_probs = probs.gather(-1, sorted_indices) |
|
|
|
|
|
|
|
|
cumulative_probs = torch.cumsum(sorted_probs, dim=-1) |
|
|
|
|
|
|
|
|
cutoff = torch.searchsorted(cumulative_probs, typical_p) |
|
|
cutoff = torch.clamp(cutoff, min=1) |
|
|
|
|
|
|
|
|
mask = torch.zeros_like(logits, dtype=torch.bool) |
|
|
for i in range(logits.size(0)): |
|
|
typical_indices = sorted_indices[i, :cutoff[i]] |
|
|
mask[i].scatter_(0, typical_indices, True) |
|
|
|
|
|
|
|
|
logits = logits.masked_fill(~mask, -float('inf')) |
|
|
|
|
|
return logits |
|
|
|
|
|
def safe_softmax_sampling(logits, temperature=1.0): |
|
|
""" |
|
|
Aplica softmax e amostragem de forma segura, evitando valores inválidos. |
|
|
""" |
|
|
|
|
|
if temperature <= 1e-5: |
|
|
temperature = 1e-5 |
|
|
|
|
|
|
|
|
logits = logits / temperature |
|
|
|
|
|
|
|
|
logits = torch.clamp(logits, min=-1e4, max=1e4) |
|
|
|
|
|
|
|
|
probs = torch.softmax(logits, dim=-1) |
|
|
|
|
|
|
|
|
if torch.any(torch.isnan(probs)) or torch.any(torch.isinf(probs)) or torch.any(probs < 0): |
|
|
|
|
|
print("AVISO: Detectados valores inválidos nas probabilidades. Usando distribuição uniforme.") |
|
|
probs = torch.ones_like(probs) / probs.size(-1) |
|
|
|
|
|
|
|
|
probs = probs / probs.sum(dim=-1, keepdim=True) |
|
|
|
|
|
return probs |
|
|
|
|
|
def generate_with_penalties(model, idx, max_new_tokens, temperature=1.0, top_k=None, |
|
|
repetition_penalty=1.0, frequency_penalty=0.0, |
|
|
presence_penalty=0.0, typical_p=1.0, min_length=0): |
|
|
""" |
|
|
Geração de texto com sistema completo de penalidades e verificações de segurança. |
|
|
""" |
|
|
|
|
|
eos_tokens = set([0, 1, 2]) |
|
|
|
|
|
original_length = idx.size(1) |
|
|
|
|
|
for step in range(max_new_tokens): |
|
|
|
|
|
idx_cond = idx if idx.size(1) <= model.config.max_seq_len else idx[:, -model.config.max_seq_len:] |
|
|
|
|
|
|
|
|
logits, _ = model(idx_cond) |
|
|
logits = logits[:, -1, :] |
|
|
|
|
|
|
|
|
logits = apply_repetition_penalty(logits, idx, repetition_penalty) |
|
|
logits = apply_frequency_penalty(logits, idx, frequency_penalty) |
|
|
logits = apply_presence_penalty(logits, idx, presence_penalty) |
|
|
|
|
|
|
|
|
current_length = idx.size(1) |
|
|
if current_length - original_length < min_length: |
|
|
for eos_token in eos_tokens: |
|
|
if eos_token < logits.size(-1): |
|
|
logits[0, eos_token] = -float('inf') |
|
|
|
|
|
|
|
|
logits = apply_typical_sampling(logits, typical_p) |
|
|
|
|
|
|
|
|
if top_k is not None: |
|
|
v, _ = torch.topk(logits, min(top_k, logits.size(-1))) |
|
|
logits[logits < v[:, [-1]]] = -float('inf') |
|
|
|
|
|
|
|
|
probs = safe_softmax_sampling(logits, temperature) |
|
|
|
|
|
|
|
|
idx_next = torch.multinomial(probs, num_samples=1) |
|
|
|
|
|
|
|
|
idx = torch.cat((idx, idx_next), dim=1) |
|
|
|
|
|
|
|
|
if idx_next.item() in eos_tokens and current_length - original_length >= min_length: |
|
|
break |
|
|
|
|
|
return idx |
|
|
|
|
|
def main(args): |
|
|
print("--- Iniciando Geração de Texto com Sistema de Penalidade ---") |
|
|
|
|
|
torch.manual_seed(1337) |
|
|
device = torch.device(args.device) |
|
|
print(f"Usando dispositivo: {device}") |
|
|
|
|
|
try: |
|
|
print(f"Carregando checkpoint de: {args.checkpoint_path}") |
|
|
checkpoint = torch.load(args.checkpoint_path, map_location=device, weights_only=False) |
|
|
except FileNotFoundError: |
|
|
print(f"ERRO: Arquivo de checkpoint não encontrado em '{args.checkpoint_path}'") |
|
|
return |
|
|
except Exception as e: |
|
|
print(f"ERRO: Falha ao carregar o checkpoint: {e}") |
|
|
return |
|
|
|
|
|
|
|
|
config = checkpoint['config']['model'] |
|
|
model = LunarisCodex(config) |
|
|
|
|
|
unwrapped_state_dict = unwrap_model_keys(checkpoint['model']) |
|
|
model.load_state_dict(unwrapped_state_dict) |
|
|
model.to(device) |
|
|
model.eval() |
|
|
|
|
|
print(f"Carregando tokenizador de: {args.tokenizer_path}") |
|
|
tokenizer = Tokenizer.from_file(args.tokenizer_path) |
|
|
|
|
|
|
|
|
start_ids = tokenizer.encode(args.prompt).ids |
|
|
x = torch.tensor(start_ids, dtype=torch.long, device=device)[None, ...] |
|
|
|
|
|
print("\n" + "="*50) |
|
|
print(f"Prompt: '{args.prompt}'") |
|
|
print("Parâmetros de geração:") |
|
|
print(f" Temperature: {args.temperature}") |
|
|
print(f" Top-k: {args.top_k}") |
|
|
print(f" Repetition penalty: {args.repetition_penalty}") |
|
|
print(f" Frequency penalty: {args.frequency_penalty}") |
|
|
print(f" Presence penalty: {args.presence_penalty}") |
|
|
print(f" Typical-p: {args.typical_p}") |
|
|
print(f" Min length: {args.min_length}") |
|
|
print("Gerando texto...") |
|
|
print("="*50) |
|
|
|
|
|
with torch.no_grad(): |
|
|
y = generate_with_penalties( |
|
|
model, x, args.max_new_tokens, |
|
|
temperature=args.temperature, |
|
|
top_k=args.top_k, |
|
|
repetition_penalty=args.repetition_penalty, |
|
|
frequency_penalty=args.frequency_penalty, |
|
|
presence_penalty=args.presence_penalty, |
|
|
typical_p=args.typical_p, |
|
|
min_length=args.min_length |
|
|
) |
|
|
|
|
|
generated_text = tokenizer.decode(y[0].tolist()) |
|
|
print(generated_text) |
|
|
print("\n--- Geração Concluída ---") |
|
|
|
|
|
if __name__ == '__main__': |
|
|
parser = argparse.ArgumentParser(description="Gerar texto com sistema de penalidades do LunarisCodex.") |
|
|
|
|
|
|
|
|
parser.add_argument('--checkpoint_path', type=str, required=True, |
|
|
help='Caminho para o arquivo .pt do checkpoint.') |
|
|
parser.add_argument('--tokenizer_path', type=str, default='./lunaris-ultrafineweb-tokenizer.json', |
|
|
help='Caminho para o arquivo do tokenizador.') |
|
|
parser.add_argument('--prompt', type=str, default='The first step to build a rocket is', |
|
|
help='O texto inicial para o modelo completar.') |
|
|
parser.add_argument('--max_new_tokens', type=int, default=50, |
|
|
help='Número de novos tokens a serem gerados.') |
|
|
parser.add_argument('--device', type=str, default='cpu', |
|
|
help='Dispositivo para rodar a geração (ex: "cpu" ou "cuda").') |
|
|
|
|
|
|
|
|
parser.add_argument('--temperature', type=float, default=0.8, |
|
|
help='Controla a aleatoriedade. Valores mais altos = mais criativo.') |
|
|
parser.add_argument('--top_k', type=int, default=200, |
|
|
help='Considera apenas os k tokens mais prováveis para amostragem.') |
|
|
|
|
|
|
|
|
parser.add_argument('--repetition_penalty', type=float, default=1.1, |
|
|
help='Penalidade de repetição. 1.0 = sem penalidade, >1.0 = penaliza repetições.') |
|
|
parser.add_argument('--frequency_penalty', type=float, default=0.0, |
|
|
help='Penalidade de frequência. 0.0 = sem penalidade, >0.0 = penaliza tokens frequentes.') |
|
|
parser.add_argument('--presence_penalty', type=float, default=0.0, |
|
|
help='Penalidade de presença. 0.0 = sem penalidade, >0.0 = penaliza tokens já usados.') |
|
|
parser.add_argument('--typical_p', type=float, default=1.0, |
|
|
help='Typical sampling. 1.0 = desabilitado, <1.0 = mantém tokens típicos.') |
|
|
parser.add_argument('--min_length', type=int, default=0, |
|
|
help='Comprimento mínimo antes de permitir tokens de fim.') |
|
|
|
|
|
args = parser.parse_args() |
|
|
main(args) |
|
|
|