Add files using upload-large-folder tool
Browse files- .hydra/config.yaml +183 -0
- .hydra/hydra.yaml +154 -0
- .hydra/overrides.yaml +1 -0
- seed_0/Qwen/Qwen2.5-7B-Instruct/adapters/README.md +207 -0
- seed_0/Qwen/Qwen2.5-7B-Instruct/adapters/agent_adapter/adapter_config.json +42 -0
- seed_0/Qwen/Qwen2.5-7B-Instruct/adapters/critic_adapter/adapter_config.json +42 -0
- src_code_for_reproducibility/__init__.py +0 -0
- src_code_for_reproducibility/chat_utils/__pycache__/apply_template.cpython-312.pyc +0 -0
- src_code_for_reproducibility/chat_utils/__pycache__/chat_turn.cpython-312.pyc +0 -0
- src_code_for_reproducibility/chat_utils/__pycache__/template_specific.cpython-312.pyc +0 -0
- src_code_for_reproducibility/docs/source/contributing.rst +0 -0
- src_code_for_reproducibility/docs/source/environments/diplomacy.rst +459 -0
- src_code_for_reproducibility/docs/source/environments/dond.rst +410 -0
- src_code_for_reproducibility/docs/source/environments/ipd.rst +411 -0
- src_code_for_reproducibility/docs/source/launch.rst +0 -0
- src_code_for_reproducibility/docs/source/media/runbatch.png +0 -0
- src_code_for_reproducibility/docs/source/modules.rst +7 -0
- src_code_for_reproducibility/docs/source/src.generation.run_games.rst +7 -0
- src_code_for_reproducibility/docs/source/src.models.dummy_hf_agent.rst +7 -0
- src_code_for_reproducibility/docs/source/src.models.dummy_local_llm.rst +7 -0
- src_code_for_reproducibility/docs/source/src.models.hf_agent.rst +7 -0
- src_code_for_reproducibility/docs/source/src.models.local_llm.rst +7 -0
- src_code_for_reproducibility/docs/source/src.models.oai_agent.rst +7 -0
- src_code_for_reproducibility/docs/source/src.models.server_llm.rst +7 -0
- src_code_for_reproducibility/docs/source/src.run.rst +7 -0
- src_code_for_reproducibility/docs/source/src.training.ppo_train.rst +7 -0
- src_code_for_reproducibility/docs/source/src.training.train_main.rst +7 -0
- src_code_for_reproducibility/docs/source/src.utils.common_imports.rst +7 -0
- src_code_for_reproducibility/docs/source/src.utils.extra_stats.rst +7 -0
- src_code_for_reproducibility/docs/source/src.utils.inherit_args.rst +7 -0
- src_code_for_reproducibility/docs/source/src.utils.log_gpu_usage.rst +7 -0
- src_code_for_reproducibility/docs/source/src.utils.log_statistics.rst +7 -0
- src_code_for_reproducibility/docs/source/src.utils.model_to_cpu.rst +7 -0
- src_code_for_reproducibility/docs/source/src.utils.quick_stats.rst +7 -0
- src_code_for_reproducibility/docs/source/usage.rst +0 -0
- src_code_for_reproducibility/models/human_policy.py +255 -0
- src_code_for_reproducibility/training/__init__.py +0 -0
- src_code_for_reproducibility/training/tally_tokenwise.py +276 -0
- src_code_for_reproducibility/training/tokenize_chats.py +128 -0
- src_code_for_reproducibility/training/trainer_ad_align.py +492 -0
- src_code_for_reproducibility/training/trainer_common.py +1054 -0
- src_code_for_reproducibility/training/trainer_independent.py +155 -0
- src_code_for_reproducibility/training/training_data_utils.py +394 -0
- src_code_for_reproducibility/utils/get_stochastic_game_lengths.py +30 -0
- src_code_for_reproducibility/utils/resource_context.py +78 -0
- src_code_for_reproducibility/utils/rollout_tree_chat_htmls.py +1921 -0
- src_code_for_reproducibility/utils/rollout_tree_stats.py +50 -0
- src_code_for_reproducibility/utils/short_id_gen.py +11 -0
- src_code_for_reproducibility/utils/stat_pack.py +113 -0
- src_code_for_reproducibility/utils/update_start_epoch.py +9 -0
.hydra/config.yaml
ADDED
|
@@ -0,0 +1,183 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
experiment:
|
| 2 |
+
wandb_enabled: true
|
| 3 |
+
nb_epochs: 3000
|
| 4 |
+
nb_matches_per_iteration: 64
|
| 5 |
+
reinit_matches_each_it: true
|
| 6 |
+
checkpoint_every_n_iterations: 10
|
| 7 |
+
start_epoch: 0
|
| 8 |
+
resume_experiment: true
|
| 9 |
+
base_seed: 0
|
| 10 |
+
seed_group_size: 8
|
| 11 |
+
train: true
|
| 12 |
+
stat_methods_for_live_wandb: mllm.markov_games.negotiation.negotiation_statistics
|
| 13 |
+
name: tas_rps_startend_ad_align_nocurrtimestep_beta2
|
| 14 |
+
agent_buffer: true
|
| 15 |
+
keep_agent_buffer_count: ${lora_count}
|
| 16 |
+
agent_buffer_recent_k: -1
|
| 17 |
+
description: Trust-and-Split Rock Paper Scissors negotiation game
|
| 18 |
+
logging:
|
| 19 |
+
wandb:
|
| 20 |
+
enabled: false
|
| 21 |
+
project: llm-negotiation
|
| 22 |
+
entity: null
|
| 23 |
+
mode: online
|
| 24 |
+
name: null
|
| 25 |
+
group: null
|
| 26 |
+
tags: []
|
| 27 |
+
notes: null
|
| 28 |
+
temperature: 1.0
|
| 29 |
+
markov_games:
|
| 30 |
+
runner_method_name: LinearRunner
|
| 31 |
+
runner_kwargs: {}
|
| 32 |
+
group_by_round: true
|
| 33 |
+
simulation_class_name: TrustAndSplitRPSSimulation
|
| 34 |
+
simulation_init_args:
|
| 35 |
+
nb_of_rounds: 10
|
| 36 |
+
quota_messages_per_agent_per_round: 1
|
| 37 |
+
alternating_hands: false
|
| 38 |
+
agents:
|
| 39 |
+
0:
|
| 40 |
+
agent_id: ${agent_0_id}
|
| 41 |
+
agent_name: Alice
|
| 42 |
+
agent_class_name: TrustAndSplitRPSAgent
|
| 43 |
+
policy_id: base_llm/agent_adapter
|
| 44 |
+
init_kwargs:
|
| 45 |
+
goal: Maximize your total points over the whole game.
|
| 46 |
+
num_message_chars: 500
|
| 47 |
+
message_start_end_format: true
|
| 48 |
+
proposal_start_end_format: true
|
| 49 |
+
1:
|
| 50 |
+
agent_id: ${agent_1_id}
|
| 51 |
+
agent_name: Bob
|
| 52 |
+
agent_class_name: TrustAndSplitRPSAgent
|
| 53 |
+
policy_id: base_llm/agent_adapter
|
| 54 |
+
init_kwargs:
|
| 55 |
+
goal: Maximize your total points over the whole game.
|
| 56 |
+
num_message_chars: 500
|
| 57 |
+
message_start_end_format: true
|
| 58 |
+
proposal_start_end_format: true
|
| 59 |
+
models:
|
| 60 |
+
base_llm:
|
| 61 |
+
class: LeanLocalLLM
|
| 62 |
+
init_args:
|
| 63 |
+
llm_id: base_llm
|
| 64 |
+
model_name: Qwen/Qwen2.5-7B-Instruct
|
| 65 |
+
inference_backend: vllm
|
| 66 |
+
hf_kwargs:
|
| 67 |
+
device_map: auto
|
| 68 |
+
torch_dtype: bfloat16
|
| 69 |
+
max_memory:
|
| 70 |
+
0: 20GiB
|
| 71 |
+
attn_implementation: flash_attention_2
|
| 72 |
+
inference_backend_init_kwargs:
|
| 73 |
+
enable_lora: true
|
| 74 |
+
seed: ${experiment.base_seed}
|
| 75 |
+
enable_prefix_caching: true
|
| 76 |
+
max_model_len: 10000.0
|
| 77 |
+
gpu_memory_utilization: 0.5
|
| 78 |
+
dtype: bfloat16
|
| 79 |
+
trust_remote_code: true
|
| 80 |
+
max_lora_rank: 32
|
| 81 |
+
enforce_eager: false
|
| 82 |
+
max_loras: ${lora_count}
|
| 83 |
+
max_cpu_loras: ${lora_count}
|
| 84 |
+
enable_sleep_mode: true
|
| 85 |
+
inference_backend_sampling_params:
|
| 86 |
+
temperature: ${temperature}
|
| 87 |
+
top_p: 1.0
|
| 88 |
+
max_tokens: 400
|
| 89 |
+
top_k: -1
|
| 90 |
+
logprobs: 0
|
| 91 |
+
adapter_configs:
|
| 92 |
+
agent_adapter:
|
| 93 |
+
task_type: CAUSAL_LM
|
| 94 |
+
r: 32
|
| 95 |
+
lora_alpha: 64
|
| 96 |
+
lora_dropout: 0.0
|
| 97 |
+
target_modules: all-linear
|
| 98 |
+
critic_adapter:
|
| 99 |
+
task_type: CAUSAL_LM
|
| 100 |
+
r: 32
|
| 101 |
+
lora_alpha: 64
|
| 102 |
+
lora_dropout: 0.0
|
| 103 |
+
target_modules: all-linear
|
| 104 |
+
enable_thinking: null
|
| 105 |
+
regex_max_attempts: 1
|
| 106 |
+
critics:
|
| 107 |
+
agent_critic:
|
| 108 |
+
module_pointer:
|
| 109 |
+
- base_llm
|
| 110 |
+
- critic_adapter
|
| 111 |
+
optimizers:
|
| 112 |
+
agent_optimizer:
|
| 113 |
+
module_pointer:
|
| 114 |
+
- base_llm
|
| 115 |
+
- agent_adapter
|
| 116 |
+
optimizer_class_name: torch.optim.Adam
|
| 117 |
+
init_args:
|
| 118 |
+
lr: 3.0e-06
|
| 119 |
+
weight_decay: 0.0
|
| 120 |
+
critic_optimizer:
|
| 121 |
+
module_pointer: agent_critic
|
| 122 |
+
optimizer_class_name: torch.optim.Adam
|
| 123 |
+
init_args:
|
| 124 |
+
lr: 3.0e-06
|
| 125 |
+
weight_decay: 0.0
|
| 126 |
+
trainers:
|
| 127 |
+
agent_trainer:
|
| 128 |
+
class: TrainerAdAlign
|
| 129 |
+
module_pointers:
|
| 130 |
+
policy:
|
| 131 |
+
- base_llm
|
| 132 |
+
- agent_adapter
|
| 133 |
+
policy_optimizer: agent_optimizer
|
| 134 |
+
critic: agent_critic
|
| 135 |
+
critic_optimizer: critic_optimizer
|
| 136 |
+
kwargs:
|
| 137 |
+
entropy_coeff: 0.0
|
| 138 |
+
entropy_topk: null
|
| 139 |
+
entropy_mask_regex: null
|
| 140 |
+
kl_coeff: 0.001
|
| 141 |
+
gradient_clipping: 1.0
|
| 142 |
+
restrict_tokens: null
|
| 143 |
+
mini_batch_size: 1
|
| 144 |
+
use_gradient_checkpointing: true
|
| 145 |
+
temperature: ${temperature}
|
| 146 |
+
device: cuda:0
|
| 147 |
+
use_gae: false
|
| 148 |
+
whiten_advantages: false
|
| 149 |
+
whiten_advantages_time_step_wise: false
|
| 150 |
+
skip_discounted_state_visitation: true
|
| 151 |
+
use_gae_lambda_annealing: false
|
| 152 |
+
gae_lambda_annealing_method: None
|
| 153 |
+
gae_lambda_annealing_method_params: None
|
| 154 |
+
gae_lambda_annealing_limit: 0.95
|
| 155 |
+
discount_factor: 0.96
|
| 156 |
+
use_rloo: true
|
| 157 |
+
enable_tokenwise_logging: false
|
| 158 |
+
pg_loss_normalization: nb_tokens
|
| 159 |
+
truncated_importance_sampling_ratio_cap: 2.0
|
| 160 |
+
reward_normalizing_constant: 100.0
|
| 161 |
+
ad_align_force_coop_first_step: false
|
| 162 |
+
ad_align_clipping: null
|
| 163 |
+
ad_align_gamma: 0.96
|
| 164 |
+
ad_align_exclude_k_equals_t: true
|
| 165 |
+
ad_align_use_sign: false
|
| 166 |
+
ad_align_beta: 2.0
|
| 167 |
+
use_old_ad_align: true
|
| 168 |
+
use_time_regularization: false
|
| 169 |
+
rloo_branch: false
|
| 170 |
+
reuse_baseline: false
|
| 171 |
+
train_on_which_data:
|
| 172 |
+
agent_trainer: ${agent_ids}
|
| 173 |
+
lora_count: 30
|
| 174 |
+
common_agent_kwargs:
|
| 175 |
+
goal: Maximize your total points over the whole game.
|
| 176 |
+
num_message_chars: 500
|
| 177 |
+
message_start_end_format: true
|
| 178 |
+
proposal_start_end_format: true
|
| 179 |
+
agent_0_id: Alice
|
| 180 |
+
agent_1_id: Bob
|
| 181 |
+
agent_ids:
|
| 182 |
+
- Alice
|
| 183 |
+
- Bob
|
.hydra/hydra.yaml
ADDED
|
@@ -0,0 +1,154 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
hydra:
|
| 2 |
+
run:
|
| 3 |
+
dir: ${oc.env:SCRATCH}/llm_negotiation/${now:%Y_%m}/${experiment.name}
|
| 4 |
+
sweep:
|
| 5 |
+
dir: multirun/${now:%Y-%m-%d}/${now:%H-%M-%S}
|
| 6 |
+
subdir: ${hydra.job.num}
|
| 7 |
+
launcher:
|
| 8 |
+
_target_: hydra._internal.core_plugins.basic_launcher.BasicLauncher
|
| 9 |
+
sweeper:
|
| 10 |
+
_target_: hydra._internal.core_plugins.basic_sweeper.BasicSweeper
|
| 11 |
+
max_batch_size: null
|
| 12 |
+
params: null
|
| 13 |
+
help:
|
| 14 |
+
app_name: ${hydra.job.name}
|
| 15 |
+
header: '${hydra.help.app_name} is powered by Hydra.
|
| 16 |
+
|
| 17 |
+
'
|
| 18 |
+
footer: 'Powered by Hydra (https://hydra.cc)
|
| 19 |
+
|
| 20 |
+
Use --hydra-help to view Hydra specific help
|
| 21 |
+
|
| 22 |
+
'
|
| 23 |
+
template: '${hydra.help.header}
|
| 24 |
+
|
| 25 |
+
== Configuration groups ==
|
| 26 |
+
|
| 27 |
+
Compose your configuration from those groups (group=option)
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
$APP_CONFIG_GROUPS
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
== Config ==
|
| 34 |
+
|
| 35 |
+
Override anything in the config (foo.bar=value)
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
$CONFIG
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
${hydra.help.footer}
|
| 42 |
+
|
| 43 |
+
'
|
| 44 |
+
hydra_help:
|
| 45 |
+
template: 'Hydra (${hydra.runtime.version})
|
| 46 |
+
|
| 47 |
+
See https://hydra.cc for more info.
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
== Flags ==
|
| 51 |
+
|
| 52 |
+
$FLAGS_HELP
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
== Configuration groups ==
|
| 56 |
+
|
| 57 |
+
Compose your configuration from those groups (For example, append hydra/job_logging=disabled
|
| 58 |
+
to command line)
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
$HYDRA_CONFIG_GROUPS
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
Use ''--cfg hydra'' to Show the Hydra config.
|
| 65 |
+
|
| 66 |
+
'
|
| 67 |
+
hydra_help: ???
|
| 68 |
+
hydra_logging:
|
| 69 |
+
version: 1
|
| 70 |
+
formatters:
|
| 71 |
+
simple:
|
| 72 |
+
format: '[%(asctime)s][HYDRA] %(message)s'
|
| 73 |
+
handlers:
|
| 74 |
+
console:
|
| 75 |
+
class: logging.StreamHandler
|
| 76 |
+
formatter: simple
|
| 77 |
+
stream: ext://sys.stdout
|
| 78 |
+
root:
|
| 79 |
+
level: INFO
|
| 80 |
+
handlers:
|
| 81 |
+
- console
|
| 82 |
+
loggers:
|
| 83 |
+
logging_example:
|
| 84 |
+
level: DEBUG
|
| 85 |
+
disable_existing_loggers: false
|
| 86 |
+
job_logging:
|
| 87 |
+
version: 1
|
| 88 |
+
formatters:
|
| 89 |
+
simple:
|
| 90 |
+
format: '[%(asctime)s][%(name)s][%(levelname)s] - %(message)s'
|
| 91 |
+
handlers:
|
| 92 |
+
console:
|
| 93 |
+
class: logging.StreamHandler
|
| 94 |
+
formatter: simple
|
| 95 |
+
stream: ext://sys.stdout
|
| 96 |
+
file:
|
| 97 |
+
class: logging.FileHandler
|
| 98 |
+
formatter: simple
|
| 99 |
+
filename: ${hydra.runtime.output_dir}/${hydra.job.name}.log
|
| 100 |
+
root:
|
| 101 |
+
level: INFO
|
| 102 |
+
handlers:
|
| 103 |
+
- console
|
| 104 |
+
- file
|
| 105 |
+
disable_existing_loggers: false
|
| 106 |
+
env: {}
|
| 107 |
+
mode: RUN
|
| 108 |
+
searchpath: []
|
| 109 |
+
callbacks: {}
|
| 110 |
+
output_subdir: .hydra
|
| 111 |
+
overrides:
|
| 112 |
+
hydra:
|
| 113 |
+
- hydra.mode=RUN
|
| 114 |
+
task: []
|
| 115 |
+
job:
|
| 116 |
+
name: run
|
| 117 |
+
chdir: false
|
| 118 |
+
override_dirname: ''
|
| 119 |
+
id: ???
|
| 120 |
+
num: ???
|
| 121 |
+
config_name: tas_rps_startend_ad_align_nocurrtimestep_beta2.yaml
|
| 122 |
+
env_set: {}
|
| 123 |
+
env_copy: []
|
| 124 |
+
config:
|
| 125 |
+
override_dirname:
|
| 126 |
+
kv_sep: '='
|
| 127 |
+
item_sep: ','
|
| 128 |
+
exclude_keys: []
|
| 129 |
+
runtime:
|
| 130 |
+
version: 1.3.2
|
| 131 |
+
version_base: '1.1'
|
| 132 |
+
cwd: /scratch/muqeeth/llm_negotiation
|
| 133 |
+
config_sources:
|
| 134 |
+
- path: hydra.conf
|
| 135 |
+
schema: pkg
|
| 136 |
+
provider: hydra
|
| 137 |
+
- path: /scratch/muqeeth/llm_negotiation/configs
|
| 138 |
+
schema: file
|
| 139 |
+
provider: main
|
| 140 |
+
- path: ''
|
| 141 |
+
schema: structured
|
| 142 |
+
provider: schema
|
| 143 |
+
output_dir: /scratch/muqeeth/llm_negotiation/2025_11/tas_rps_startend_ad_align_nocurrtimestep_beta2
|
| 144 |
+
choices:
|
| 145 |
+
hydra/env: default
|
| 146 |
+
hydra/callbacks: null
|
| 147 |
+
hydra/job_logging: default
|
| 148 |
+
hydra/hydra_logging: default
|
| 149 |
+
hydra/hydra_help: default
|
| 150 |
+
hydra/help: default
|
| 151 |
+
hydra/sweeper: basic
|
| 152 |
+
hydra/launcher: basic
|
| 153 |
+
hydra/output: default
|
| 154 |
+
verbose: false
|
.hydra/overrides.yaml
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
[]
|
seed_0/Qwen/Qwen2.5-7B-Instruct/adapters/README.md
ADDED
|
@@ -0,0 +1,207 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
base_model: Qwen/Qwen2.5-7B-Instruct
|
| 3 |
+
library_name: peft
|
| 4 |
+
pipeline_tag: text-generation
|
| 5 |
+
tags:
|
| 6 |
+
- base_model:adapter:Qwen/Qwen2.5-7B-Instruct
|
| 7 |
+
- lora
|
| 8 |
+
- transformers
|
| 9 |
+
---
|
| 10 |
+
|
| 11 |
+
# Model Card for Model ID
|
| 12 |
+
|
| 13 |
+
<!-- Provide a quick summary of what the model is/does. -->
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
## Model Details
|
| 18 |
+
|
| 19 |
+
### Model Description
|
| 20 |
+
|
| 21 |
+
<!-- Provide a longer summary of what this model is. -->
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
- **Developed by:** [More Information Needed]
|
| 26 |
+
- **Funded by [optional]:** [More Information Needed]
|
| 27 |
+
- **Shared by [optional]:** [More Information Needed]
|
| 28 |
+
- **Model type:** [More Information Needed]
|
| 29 |
+
- **Language(s) (NLP):** [More Information Needed]
|
| 30 |
+
- **License:** [More Information Needed]
|
| 31 |
+
- **Finetuned from model [optional]:** [More Information Needed]
|
| 32 |
+
|
| 33 |
+
### Model Sources [optional]
|
| 34 |
+
|
| 35 |
+
<!-- Provide the basic links for the model. -->
|
| 36 |
+
|
| 37 |
+
- **Repository:** [More Information Needed]
|
| 38 |
+
- **Paper [optional]:** [More Information Needed]
|
| 39 |
+
- **Demo [optional]:** [More Information Needed]
|
| 40 |
+
|
| 41 |
+
## Uses
|
| 42 |
+
|
| 43 |
+
<!-- Address questions around how the model is intended to be used, including the foreseeable users of the model and those affected by the model. -->
|
| 44 |
+
|
| 45 |
+
### Direct Use
|
| 46 |
+
|
| 47 |
+
<!-- This section is for the model use without fine-tuning or plugging into a larger ecosystem/app. -->
|
| 48 |
+
|
| 49 |
+
[More Information Needed]
|
| 50 |
+
|
| 51 |
+
### Downstream Use [optional]
|
| 52 |
+
|
| 53 |
+
<!-- This section is for the model use when fine-tuned for a task, or when plugged into a larger ecosystem/app -->
|
| 54 |
+
|
| 55 |
+
[More Information Needed]
|
| 56 |
+
|
| 57 |
+
### Out-of-Scope Use
|
| 58 |
+
|
| 59 |
+
<!-- This section addresses misuse, malicious use, and uses that the model will not work well for. -->
|
| 60 |
+
|
| 61 |
+
[More Information Needed]
|
| 62 |
+
|
| 63 |
+
## Bias, Risks, and Limitations
|
| 64 |
+
|
| 65 |
+
<!-- This section is meant to convey both technical and sociotechnical limitations. -->
|
| 66 |
+
|
| 67 |
+
[More Information Needed]
|
| 68 |
+
|
| 69 |
+
### Recommendations
|
| 70 |
+
|
| 71 |
+
<!-- This section is meant to convey recommendations with respect to the bias, risk, and technical limitations. -->
|
| 72 |
+
|
| 73 |
+
Users (both direct and downstream) should be made aware of the risks, biases and limitations of the model. More information needed for further recommendations.
|
| 74 |
+
|
| 75 |
+
## How to Get Started with the Model
|
| 76 |
+
|
| 77 |
+
Use the code below to get started with the model.
|
| 78 |
+
|
| 79 |
+
[More Information Needed]
|
| 80 |
+
|
| 81 |
+
## Training Details
|
| 82 |
+
|
| 83 |
+
### Training Data
|
| 84 |
+
|
| 85 |
+
<!-- This should link to a Dataset Card, perhaps with a short stub of information on what the training data is all about as well as documentation related to data pre-processing or additional filtering. -->
|
| 86 |
+
|
| 87 |
+
[More Information Needed]
|
| 88 |
+
|
| 89 |
+
### Training Procedure
|
| 90 |
+
|
| 91 |
+
<!-- This relates heavily to the Technical Specifications. Content here should link to that section when it is relevant to the training procedure. -->
|
| 92 |
+
|
| 93 |
+
#### Preprocessing [optional]
|
| 94 |
+
|
| 95 |
+
[More Information Needed]
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
#### Training Hyperparameters
|
| 99 |
+
|
| 100 |
+
- **Training regime:** [More Information Needed] <!--fp32, fp16 mixed precision, bf16 mixed precision, bf16 non-mixed precision, fp16 non-mixed precision, fp8 mixed precision -->
|
| 101 |
+
|
| 102 |
+
#### Speeds, Sizes, Times [optional]
|
| 103 |
+
|
| 104 |
+
<!-- This section provides information about throughput, start/end time, checkpoint size if relevant, etc. -->
|
| 105 |
+
|
| 106 |
+
[More Information Needed]
|
| 107 |
+
|
| 108 |
+
## Evaluation
|
| 109 |
+
|
| 110 |
+
<!-- This section describes the evaluation protocols and provides the results. -->
|
| 111 |
+
|
| 112 |
+
### Testing Data, Factors & Metrics
|
| 113 |
+
|
| 114 |
+
#### Testing Data
|
| 115 |
+
|
| 116 |
+
<!-- This should link to a Dataset Card if possible. -->
|
| 117 |
+
|
| 118 |
+
[More Information Needed]
|
| 119 |
+
|
| 120 |
+
#### Factors
|
| 121 |
+
|
| 122 |
+
<!-- These are the things the evaluation is disaggregating by, e.g., subpopulations or domains. -->
|
| 123 |
+
|
| 124 |
+
[More Information Needed]
|
| 125 |
+
|
| 126 |
+
#### Metrics
|
| 127 |
+
|
| 128 |
+
<!-- These are the evaluation metrics being used, ideally with a description of why. -->
|
| 129 |
+
|
| 130 |
+
[More Information Needed]
|
| 131 |
+
|
| 132 |
+
### Results
|
| 133 |
+
|
| 134 |
+
[More Information Needed]
|
| 135 |
+
|
| 136 |
+
#### Summary
|
| 137 |
+
|
| 138 |
+
|
| 139 |
+
|
| 140 |
+
## Model Examination [optional]
|
| 141 |
+
|
| 142 |
+
<!-- Relevant interpretability work for the model goes here -->
|
| 143 |
+
|
| 144 |
+
[More Information Needed]
|
| 145 |
+
|
| 146 |
+
## Environmental Impact
|
| 147 |
+
|
| 148 |
+
<!-- Total emissions (in grams of CO2eq) and additional considerations, such as electricity usage, go here. Edit the suggested text below accordingly -->
|
| 149 |
+
|
| 150 |
+
Carbon emissions can be estimated using the [Machine Learning Impact calculator](https://mlco2.github.io/impact#compute) presented in [Lacoste et al. (2019)](https://arxiv.org/abs/1910.09700).
|
| 151 |
+
|
| 152 |
+
- **Hardware Type:** [More Information Needed]
|
| 153 |
+
- **Hours used:** [More Information Needed]
|
| 154 |
+
- **Cloud Provider:** [More Information Needed]
|
| 155 |
+
- **Compute Region:** [More Information Needed]
|
| 156 |
+
- **Carbon Emitted:** [More Information Needed]
|
| 157 |
+
|
| 158 |
+
## Technical Specifications [optional]
|
| 159 |
+
|
| 160 |
+
### Model Architecture and Objective
|
| 161 |
+
|
| 162 |
+
[More Information Needed]
|
| 163 |
+
|
| 164 |
+
### Compute Infrastructure
|
| 165 |
+
|
| 166 |
+
[More Information Needed]
|
| 167 |
+
|
| 168 |
+
#### Hardware
|
| 169 |
+
|
| 170 |
+
[More Information Needed]
|
| 171 |
+
|
| 172 |
+
#### Software
|
| 173 |
+
|
| 174 |
+
[More Information Needed]
|
| 175 |
+
|
| 176 |
+
## Citation [optional]
|
| 177 |
+
|
| 178 |
+
<!-- If there is a paper or blog post introducing the model, the APA and Bibtex information for that should go in this section. -->
|
| 179 |
+
|
| 180 |
+
**BibTeX:**
|
| 181 |
+
|
| 182 |
+
[More Information Needed]
|
| 183 |
+
|
| 184 |
+
**APA:**
|
| 185 |
+
|
| 186 |
+
[More Information Needed]
|
| 187 |
+
|
| 188 |
+
## Glossary [optional]
|
| 189 |
+
|
| 190 |
+
<!-- If relevant, include terms and calculations in this section that can help readers understand the model or model card. -->
|
| 191 |
+
|
| 192 |
+
[More Information Needed]
|
| 193 |
+
|
| 194 |
+
## More Information [optional]
|
| 195 |
+
|
| 196 |
+
[More Information Needed]
|
| 197 |
+
|
| 198 |
+
## Model Card Authors [optional]
|
| 199 |
+
|
| 200 |
+
[More Information Needed]
|
| 201 |
+
|
| 202 |
+
## Model Card Contact
|
| 203 |
+
|
| 204 |
+
[More Information Needed]
|
| 205 |
+
### Framework versions
|
| 206 |
+
|
| 207 |
+
- PEFT 0.17.1
|
seed_0/Qwen/Qwen2.5-7B-Instruct/adapters/agent_adapter/adapter_config.json
ADDED
|
@@ -0,0 +1,42 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"alpha_pattern": {},
|
| 3 |
+
"auto_mapping": null,
|
| 4 |
+
"base_model_name_or_path": "Qwen/Qwen2.5-7B-Instruct",
|
| 5 |
+
"bias": "none",
|
| 6 |
+
"corda_config": null,
|
| 7 |
+
"eva_config": null,
|
| 8 |
+
"exclude_modules": null,
|
| 9 |
+
"fan_in_fan_out": false,
|
| 10 |
+
"inference_mode": true,
|
| 11 |
+
"init_lora_weights": true,
|
| 12 |
+
"layer_replication": null,
|
| 13 |
+
"layers_pattern": null,
|
| 14 |
+
"layers_to_transform": null,
|
| 15 |
+
"loftq_config": {},
|
| 16 |
+
"lora_alpha": 64,
|
| 17 |
+
"lora_bias": false,
|
| 18 |
+
"lora_dropout": 0.0,
|
| 19 |
+
"megatron_config": null,
|
| 20 |
+
"megatron_core": "megatron.core",
|
| 21 |
+
"modules_to_save": null,
|
| 22 |
+
"peft_type": "LORA",
|
| 23 |
+
"qalora_group_size": 16,
|
| 24 |
+
"r": 32,
|
| 25 |
+
"rank_pattern": {},
|
| 26 |
+
"revision": null,
|
| 27 |
+
"target_modules": [
|
| 28 |
+
"up_proj",
|
| 29 |
+
"q_proj",
|
| 30 |
+
"v_proj",
|
| 31 |
+
"gate_proj",
|
| 32 |
+
"down_proj",
|
| 33 |
+
"k_proj",
|
| 34 |
+
"o_proj"
|
| 35 |
+
],
|
| 36 |
+
"target_parameters": null,
|
| 37 |
+
"task_type": "CAUSAL_LM",
|
| 38 |
+
"trainable_token_indices": null,
|
| 39 |
+
"use_dora": false,
|
| 40 |
+
"use_qalora": false,
|
| 41 |
+
"use_rslora": false
|
| 42 |
+
}
|
seed_0/Qwen/Qwen2.5-7B-Instruct/adapters/critic_adapter/adapter_config.json
ADDED
|
@@ -0,0 +1,42 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"alpha_pattern": {},
|
| 3 |
+
"auto_mapping": null,
|
| 4 |
+
"base_model_name_or_path": "Qwen/Qwen2.5-7B-Instruct",
|
| 5 |
+
"bias": "none",
|
| 6 |
+
"corda_config": null,
|
| 7 |
+
"eva_config": null,
|
| 8 |
+
"exclude_modules": null,
|
| 9 |
+
"fan_in_fan_out": false,
|
| 10 |
+
"inference_mode": true,
|
| 11 |
+
"init_lora_weights": true,
|
| 12 |
+
"layer_replication": null,
|
| 13 |
+
"layers_pattern": null,
|
| 14 |
+
"layers_to_transform": null,
|
| 15 |
+
"loftq_config": {},
|
| 16 |
+
"lora_alpha": 64,
|
| 17 |
+
"lora_bias": false,
|
| 18 |
+
"lora_dropout": 0.0,
|
| 19 |
+
"megatron_config": null,
|
| 20 |
+
"megatron_core": "megatron.core",
|
| 21 |
+
"modules_to_save": null,
|
| 22 |
+
"peft_type": "LORA",
|
| 23 |
+
"qalora_group_size": 16,
|
| 24 |
+
"r": 32,
|
| 25 |
+
"rank_pattern": {},
|
| 26 |
+
"revision": null,
|
| 27 |
+
"target_modules": [
|
| 28 |
+
"up_proj",
|
| 29 |
+
"q_proj",
|
| 30 |
+
"v_proj",
|
| 31 |
+
"gate_proj",
|
| 32 |
+
"down_proj",
|
| 33 |
+
"k_proj",
|
| 34 |
+
"o_proj"
|
| 35 |
+
],
|
| 36 |
+
"target_parameters": null,
|
| 37 |
+
"task_type": "CAUSAL_LM",
|
| 38 |
+
"trainable_token_indices": null,
|
| 39 |
+
"use_dora": false,
|
| 40 |
+
"use_qalora": false,
|
| 41 |
+
"use_rslora": false
|
| 42 |
+
}
|
src_code_for_reproducibility/__init__.py
ADDED
|
File without changes
|
src_code_for_reproducibility/chat_utils/__pycache__/apply_template.cpython-312.pyc
ADDED
|
Binary file (3.92 kB). View file
|
|
|
src_code_for_reproducibility/chat_utils/__pycache__/chat_turn.cpython-312.pyc
ADDED
|
Binary file (1.32 kB). View file
|
|
|
src_code_for_reproducibility/chat_utils/__pycache__/template_specific.cpython-312.pyc
ADDED
|
Binary file (4.24 kB). View file
|
|
|
src_code_for_reproducibility/docs/source/contributing.rst
ADDED
|
File without changes
|
src_code_for_reproducibility/docs/source/environments/diplomacy.rst
ADDED
|
@@ -0,0 +1,459 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
=================
|
| 2 |
+
Diplomacy
|
| 3 |
+
=================
|
| 4 |
+
|
| 5 |
+
The Diplomacy environment provides a multi-agent negotiation interface for the classic board game Diplomacy,
|
| 6 |
+
based on DeepMind's implementation. This document describes the API for interacting with the Diplomacy environment
|
| 7 |
+
and its associated agent handler.
|
| 8 |
+
|
| 9 |
+
Overview
|
| 10 |
+
--------
|
| 11 |
+
|
| 12 |
+
Diplomacy is a strategic board game set in Europe before World War I, where players control one of seven European powers
|
| 13 |
+
and negotiate with each other to gain control of supply centers. The game is played in turns, with each turn consisting
|
| 14 |
+
of movement phases, retreat phases, and build phases.
|
| 15 |
+
|
| 16 |
+
Our implementation adapts DeepMind's Diplomacy code to the Multi-Agent Negotiation Environment standard, allowing it
|
| 17 |
+
to be used with LLM agents through a text-based interface.
|
| 18 |
+
|
| 19 |
+
Game Rules
|
| 20 |
+
----------
|
| 21 |
+
|
| 22 |
+
### Game Board and Powers
|
| 23 |
+
|
| 24 |
+
Diplomacy is played on a map of Europe divided into provinces. The game features seven Great Powers that players can control:
|
| 25 |
+
|
| 26 |
+
- England (blue)
|
| 27 |
+
- France (light blue)
|
| 28 |
+
- Germany (black)
|
| 29 |
+
- Italy (green)
|
| 30 |
+
- Austria-Hungary (red)
|
| 31 |
+
- Russia (white)
|
| 32 |
+
- Turkey (yellow)
|
| 33 |
+
|
| 34 |
+
Each power begins with three supply centers (except Russia, which starts with four) and an equal number of units.
|
| 35 |
+
|
| 36 |
+
### Units and Movement
|
| 37 |
+
|
| 38 |
+
There are two types of units in Diplomacy:
|
| 39 |
+
- **Armies (A)**: Can move to adjacent land provinces or be convoyed across water by fleets
|
| 40 |
+
- **Fleets (F)**: Can move to adjacent coastal provinces and sea regions
|
| 41 |
+
|
| 42 |
+
During movement phases, each unit can execute one of these orders:
|
| 43 |
+
- **Hold**: The unit remains in its current province (e.g., "A PAR H")
|
| 44 |
+
- Format: [Unit Type] [Province] H
|
| 45 |
+
- Example: "A PAR H" means "Army in Paris holds its position"
|
| 46 |
+
|
| 47 |
+
- **Move**: The unit attempts to move to an adjacent province (e.g., "A PAR - BUR")
|
| 48 |
+
- Format: [Unit Type] [Current Province] - [Destination Province]
|
| 49 |
+
- Example: "A PAR - BUR" means "Army in Paris moves to Burgundy"
|
| 50 |
+
- Example: "F BRE - ENG" means "Fleet in Brest moves to the English Channel"
|
| 51 |
+
|
| 52 |
+
- **Support**: The unit supports another unit's move or hold (e.g., "A PAR S A MAR - BUR")
|
| 53 |
+
- Format for supporting a move: [Unit Type] [Province] S [Unit Type] [Province] - [Destination]
|
| 54 |
+
- Format for supporting a hold: [Unit Type] [Province] S [Unit Type] [Province]
|
| 55 |
+
- Example: "A PAR S A MAR - BUR" means "Army in Paris supports the Army in Marseille's move to Burgundy"
|
| 56 |
+
- Example: "F LON S F NTH" means "Fleet in London supports the Fleet in North Sea holding its position"
|
| 57 |
+
|
| 58 |
+
- **Convoy**: A fleet can convoy an army across water (e.g., "F ENG C A LON - BRE")
|
| 59 |
+
- Format: [Fleet] [Sea Province] C [Army] [Coastal Province] - [Coastal Province]
|
| 60 |
+
- Example: "F ENG C A LON - BRE" means "Fleet in English Channel convoys the Army in London to Brest"
|
| 61 |
+
|
| 62 |
+
All orders are executed simultaneously, and conflicts are resolved based on strength (number of supporting units).
|
| 63 |
+
|
| 64 |
+
### Common Province Abbreviations
|
| 65 |
+
|
| 66 |
+
Diplomacy uses three-letter abbreviations for provinces. Some common ones include:
|
| 67 |
+
- **PAR**: Paris
|
| 68 |
+
- **LON**: London
|
| 69 |
+
- **BER**: Berlin
|
| 70 |
+
- **MUN**: Munich
|
| 71 |
+
- **BUR**: Burgundy
|
| 72 |
+
- **MAR**: Marseilles
|
| 73 |
+
- **BRE**: Brest
|
| 74 |
+
- **ENG**: English Channel
|
| 75 |
+
- **NTH**: North Sea
|
| 76 |
+
- **VIE**: Vienna
|
| 77 |
+
- **ROM**: Rome
|
| 78 |
+
- **VEN**: Venice
|
| 79 |
+
- **MOW**: Moscow
|
| 80 |
+
- **CON**: Constantinople
|
| 81 |
+
|
| 82 |
+
### Example: Movement and Conflicts
|
| 83 |
+
|
| 84 |
+
For example, if France orders "A PAR - BUR" and Germany orders "A MUN - BUR", neither move succeeds as they have equal strength. However, if France also orders "A MAR S A PAR - BUR", then the French army from Paris would successfully move to Burgundy with strength of 2 against Germany's strength of 1.
|
| 85 |
+
|
| 86 |
+
### Turn Structure
|
| 87 |
+
|
| 88 |
+
A game year consists of five phases:
|
| 89 |
+
1. **Spring Movement**: All powers submit orders for their units
|
| 90 |
+
2. **Spring Retreat**: Units dislodged in the movement phase must retreat or be disbanded
|
| 91 |
+
3. **Fall Movement**: Another round of movement orders
|
| 92 |
+
4. **Fall Retreat**: Retreat orders for dislodged units
|
| 93 |
+
5. **Winter Adjustment**: Powers gain or lose units based on the number of supply centers they control
|
| 94 |
+
|
| 95 |
+
### Supply Centers and Building
|
| 96 |
+
|
| 97 |
+
Supply centers (marked on the map) are key to victory. When a power occupies a supply center during a Fall turn, they gain control of it. During the Winter Adjustment phase:
|
| 98 |
+
- If you control more supply centers than you have units, you can build new units in your home supply centers
|
| 99 |
+
- If you control fewer supply centers than you have units, you must remove excess units
|
| 100 |
+
|
| 101 |
+
### Example: Building and Removing Units
|
| 102 |
+
|
| 103 |
+
If France controls 5 supply centers but only has 4 units, during the Winter phase they can build one new unit in an unoccupied home supply center (Paris, Marseilles, or Brest). Conversely, if France controls only 3 supply centers but has 4 units, they must remove one unit of their choice.
|
| 104 |
+
|
| 105 |
+
### Negotiation
|
| 106 |
+
|
| 107 |
+
A critical component of Diplomacy is the negotiation between players. Before submitting orders, players can communicate freely to form alliances, coordinate attacks, or mislead opponents. These negotiations are not binding, and betrayal is a common strategy.
|
| 108 |
+
|
| 109 |
+
### Example: Alliance and Betrayal
|
| 110 |
+
|
| 111 |
+
England and France might agree to an alliance against Germany, with England promising to support France's move into Belgium. However, England could secretly order their fleet to move into Belgium themselves or support a German move instead.
|
| 112 |
+
|
| 113 |
+
### Victory Conditions
|
| 114 |
+
|
| 115 |
+
The game ends when one power controls 18 or more supply centers (majority of the 34 total centers), or when players agree to a draw. In tournament settings, games may also end after a predetermined number of game years.
|
| 116 |
+
|
| 117 |
+
DiplomacyEnv
|
| 118 |
+
------------
|
| 119 |
+
|
| 120 |
+
The ``DiplomacyEnv`` class provides an interface to the Diplomacy game environment that follows the Multi-Agent
|
| 121 |
+
Negotiation Environment standard.
|
| 122 |
+
|
| 123 |
+
.. code-block:: python
|
| 124 |
+
|
| 125 |
+
class DiplomacyEnv:
|
| 126 |
+
"""
|
| 127 |
+
Multi-Agent Negotiation Environment for Diplomacy, adapting Deepmind's implementation
|
| 128 |
+
to the MarlEnvironment standard.
|
| 129 |
+
"""
|
| 130 |
+
def __init__(self,
|
| 131 |
+
initial_state: Optional[DiplomacyState] = None,
|
| 132 |
+
max_turns: int = 100,
|
| 133 |
+
points_per_supply_centre: bool = True,
|
| 134 |
+
forced_draw_probability: float = 0.0,
|
| 135 |
+
min_years_forced_draw: int = 35):
|
| 136 |
+
"""Initialize the Diplomacy environment.
|
| 137 |
+
|
| 138 |
+
Args:
|
| 139 |
+
initial_state: Initial DiplomacyState (optional)
|
| 140 |
+
max_turns: Maximum number of turns in the game
|
| 141 |
+
points_per_supply_centre: Whether to award points per supply center in case of a draw
|
| 142 |
+
forced_draw_probability: Probability of forcing a draw after min_years_forced_draw
|
| 143 |
+
min_years_forced_draw: Minimum years before considering a forced draw
|
| 144 |
+
"""
|
| 145 |
+
# ...
|
| 146 |
+
|
| 147 |
+
def reset(self):
|
| 148 |
+
"""Reset the environment to an initial state and return the initial observation.
|
| 149 |
+
|
| 150 |
+
Returns:
|
| 151 |
+
observation (dict): A dictionary where keys are agent identifiers and values are observations.
|
| 152 |
+
Each observation contains:
|
| 153 |
+
- board_state: Current state of the board
|
| 154 |
+
- current_season: Current season in the game
|
| 155 |
+
- player_index: Index of the player's power
|
| 156 |
+
- possible_actions: List of possible actions in DeepMind's format
|
| 157 |
+
- human_readable_actions: List of human-readable action descriptions
|
| 158 |
+
- supply_centers: List of supply centers owned by the player
|
| 159 |
+
- units: List of units owned by the player
|
| 160 |
+
- year: Current year in the game
|
| 161 |
+
"""
|
| 162 |
+
# ...
|
| 163 |
+
|
| 164 |
+
def step(self, actions):
|
| 165 |
+
"""Take a step in the environment using the provided actions.
|
| 166 |
+
|
| 167 |
+
Args:
|
| 168 |
+
actions (dict): A dictionary where keys are agent identifiers and values are actions.
|
| 169 |
+
Actions can be:
|
| 170 |
+
- List of integer actions in DeepMind's format
|
| 171 |
+
- List of string actions in text format (e.g., "A MUN - BER")
|
| 172 |
+
|
| 173 |
+
Returns:
|
| 174 |
+
observations (dict): A dictionary where keys are agent identifiers and values are observations.
|
| 175 |
+
Each observation has the same structure as in reset().
|
| 176 |
+
done (bool): Whether the episode has ended.
|
| 177 |
+
info (dict): Additional information about the environment, including:
|
| 178 |
+
- turn: Current turn number
|
| 179 |
+
- returns: Game returns if the game is done, otherwise None
|
| 180 |
+
- waiting_for: List of agents that still need to provide actions (if not all actions are provided)
|
| 181 |
+
"""
|
| 182 |
+
# ...
|
| 183 |
+
|
| 184 |
+
def get_log_info(self):
|
| 185 |
+
"""Get additional information about the environment for logging.
|
| 186 |
+
|
| 187 |
+
Returns:
|
| 188 |
+
log_info (dict): Information about the environment required to log the game, including:
|
| 189 |
+
- power_names: List of power names
|
| 190 |
+
- game_history: History of the game
|
| 191 |
+
- current_turn: Current turn number
|
| 192 |
+
- current_season: Current season name
|
| 193 |
+
- supply_centers: Dictionary mapping power names to supply center counts
|
| 194 |
+
"""
|
| 195 |
+
# ...
|
| 196 |
+
|
| 197 |
+
def render(self):
|
| 198 |
+
"""Render the current state of the environment.
|
| 199 |
+
|
| 200 |
+
Displays a visualization of the current game state.
|
| 201 |
+
"""
|
| 202 |
+
# ...
|
| 203 |
+
|
| 204 |
+
def close(self):
|
| 205 |
+
"""Perform any necessary cleanup."""
|
| 206 |
+
# ...
|
| 207 |
+
|
| 208 |
+
|
| 209 |
+
Key Implementation Details
|
| 210 |
+
~~~~~~~~~~~~~~~~~~~~~~~~~
|
| 211 |
+
|
| 212 |
+
The ``DiplomacyEnv`` class implements several key features:
|
| 213 |
+
|
| 214 |
+
1. **Multi-Agent Support**: The environment tracks multiple agents (powers) and manages their interactions.
|
| 215 |
+
|
| 216 |
+
2. **Turn-Based Gameplay**: The environment enforces the turn structure of Diplomacy, including different phases.
|
| 217 |
+
|
| 218 |
+
3. **Action Processing**: The environment can handle actions in both text format and DeepMind's integer format.
|
| 219 |
+
|
| 220 |
+
4. **Observation Generation**: The environment generates detailed observations for each agent, including board state, supply centers, and possible actions.
|
| 221 |
+
|
| 222 |
+
5. **Game Termination**: The environment tracks game termination conditions, including supply center victory and maximum turn limits.
|
| 223 |
+
|
| 224 |
+
Observation Structure
|
| 225 |
+
~~~~~~~~~~~~~~~~~~~~
|
| 226 |
+
|
| 227 |
+
Each agent receives an observation dictionary with the following structure:
|
| 228 |
+
|
| 229 |
+
.. code-block:: python
|
| 230 |
+
|
| 231 |
+
{
|
| 232 |
+
"board_state": np.ndarray, # Board state representation
|
| 233 |
+
"current_season": int, # Season index (0-4)
|
| 234 |
+
"player_index": int, # Index of the player's power (0-6)
|
| 235 |
+
"possible_actions": [int], # List of possible actions in DeepMind's format
|
| 236 |
+
"human_readable_actions": [str], # List of human-readable action descriptions
|
| 237 |
+
"supply_centers": [str], # List of supply centers owned by the player
|
| 238 |
+
"units": [dict], # List of units owned by the player
|
| 239 |
+
"year": int # Current year in the game
|
| 240 |
+
}
|
| 241 |
+
|
| 242 |
+
Action Structure
|
| 243 |
+
~~~~~~~~~~~~~~~
|
| 244 |
+
|
| 245 |
+
Actions can be provided in two formats:
|
| 246 |
+
|
| 247 |
+
1. **Text Format**: String actions like ``"A MUN - BER"`` or ``"F NTH C A LON - BEL"``.
|
| 248 |
+
|
| 249 |
+
2. **Integer Format**: Lists of integers corresponding to DeepMind's action representation.
|
| 250 |
+
|
| 251 |
+
The environment will convert text actions to the internal format as needed.
|
| 252 |
+
|
| 253 |
+
DiplomacyAgent
|
| 254 |
+
--------------
|
| 255 |
+
|
| 256 |
+
The ``DiplomacyAgent`` class implements the agent handler interface for Diplomacy, processing observations from the environment and generating actions through an LLM.
|
| 257 |
+
|
| 258 |
+
.. code-block:: python
|
| 259 |
+
|
| 260 |
+
class DiplomacyAgent:
|
| 261 |
+
"""
|
| 262 |
+
Agent handler for Diplomacy, implementing the AgentState interface
|
| 263 |
+
for the multi-agent negotiation standard.
|
| 264 |
+
"""
|
| 265 |
+
|
| 266 |
+
def __init__(self,
|
| 267 |
+
power_name: str,
|
| 268 |
+
use_text_interface: bool = True,
|
| 269 |
+
system_prompt: Optional[str] = None):
|
| 270 |
+
"""Initialize the Diplomacy agent handler.
|
| 271 |
+
|
| 272 |
+
Args:
|
| 273 |
+
power_name: Name of the power this agent controls
|
| 274 |
+
use_text_interface: Whether to use text-based interface (vs. structured)
|
| 275 |
+
system_prompt: Optional system prompt to use for the LLM
|
| 276 |
+
"""
|
| 277 |
+
# ...
|
| 278 |
+
|
| 279 |
+
def step(self, observation_from_env, policy_output=None):
|
| 280 |
+
"""Update the agent state based on the observation and action.
|
| 281 |
+
|
| 282 |
+
Args:
|
| 283 |
+
observation_from_env: The observation from the environment, with structure:
|
| 284 |
+
- board_state: Current state of the board
|
| 285 |
+
- current_season: Current season in the game
|
| 286 |
+
- player_index: Index of the player's power
|
| 287 |
+
- possible_actions: List of possible actions
|
| 288 |
+
- human_readable_actions: List of human-readable action descriptions
|
| 289 |
+
- supply_centers: List of supply centers owned by the player
|
| 290 |
+
- units: List of units owned by the player
|
| 291 |
+
- year: Current year in the game
|
| 292 |
+
|
| 293 |
+
policy_output: The output of the policy (LLM response), or None for initial prompt
|
| 294 |
+
|
| 295 |
+
Returns:
|
| 296 |
+
policy_id (str): The policy identifier ("llm_policy")
|
| 297 |
+
policy_input (dict): The input to the policy, with structure:
|
| 298 |
+
- messages: List of conversation messages in the format:
|
| 299 |
+
[{"role": "system", "content": "..."},
|
| 300 |
+
{"role": "user", "content": "..."}]
|
| 301 |
+
action: The official action to be sent to the environment, or None if not ready
|
| 302 |
+
done (bool): Whether the LLM action is ready to be sent to the environment
|
| 303 |
+
info (dict): Additional information about the agent:
|
| 304 |
+
- valid_action: Whether the extracted action is valid
|
| 305 |
+
"""
|
| 306 |
+
# ...
|
| 307 |
+
|
| 308 |
+
def get_log_info(self):
|
| 309 |
+
"""Get information about the agent required to log a trajectory.
|
| 310 |
+
|
| 311 |
+
Returns:
|
| 312 |
+
log_info (dict): Information about the agent required to log a trajectory:
|
| 313 |
+
- power_name: Name of the power this agent controls
|
| 314 |
+
- conversation_history: List of conversation messages
|
| 315 |
+
- current_action: The current action, if any
|
| 316 |
+
"""
|
| 317 |
+
# ...
|
| 318 |
+
|
| 319 |
+
def render(self):
|
| 320 |
+
"""Render the current state of the agent.
|
| 321 |
+
|
| 322 |
+
Displays the agent's current state, including conversation history.
|
| 323 |
+
"""
|
| 324 |
+
# ...
|
| 325 |
+
|
| 326 |
+
def close(self):
|
| 327 |
+
"""Perform any necessary cleanup."""
|
| 328 |
+
# ...
|
| 329 |
+
|
| 330 |
+
|
| 331 |
+
Key Implementation Details
|
| 332 |
+
~~~~~~~~~~~~~~~~~~~~~~~~~
|
| 333 |
+
|
| 334 |
+
The ``DiplomacyAgent`` class implements several key features:
|
| 335 |
+
|
| 336 |
+
1. **LLM Interaction**: The agent generates prompts for an LLM and processes the LLM's responses to extract actions.
|
| 337 |
+
|
| 338 |
+
2. **Conversation Management**: The agent maintains a conversation history for coherent interactions with the LLM.
|
| 339 |
+
|
| 340 |
+
3. **Action Validation**: The agent validates extracted actions against the set of possible actions provided by the environment.
|
| 341 |
+
|
| 342 |
+
4. **Error Handling**: The agent generates clarification prompts when invalid actions are detected.
|
| 343 |
+
|
| 344 |
+
5. **Text-Based Interface**: The agent formats game state information into human-readable text for the LLM.
|
| 345 |
+
|
| 346 |
+
Prompt Structure
|
| 347 |
+
~~~~~~~~~~~~~~~
|
| 348 |
+
|
| 349 |
+
The agent generates prompts that include:
|
| 350 |
+
|
| 351 |
+
1. **System Prompt**: Instructions and context for the LLM, explaining its role as a Diplomacy player.
|
| 352 |
+
|
| 353 |
+
2. **Game State Description**: A text description of the current game state, including:
|
| 354 |
+
- Current year and season
|
| 355 |
+
- Supply centers owned
|
| 356 |
+
- Units controlled
|
| 357 |
+
- Possible actions
|
| 358 |
+
|
| 359 |
+
3. **Action Request**: Instructions on how to format actions.
|
| 360 |
+
|
| 361 |
+
Example system prompt:
|
| 362 |
+
|
| 363 |
+
.. code-block:: text
|
| 364 |
+
|
| 365 |
+
You are playing the role of FRANCE in a game of Diplomacy.
|
| 366 |
+
Your goal is to control as many supply centers as possible.
|
| 367 |
+
You can negotiate with other players and form alliances, but remember that
|
| 368 |
+
these alliances are not binding. When you need to submit orders for your units,
|
| 369 |
+
write them in the correct format, with each order on a new line.
|
| 370 |
+
|
| 371 |
+
Example game state description:
|
| 372 |
+
|
| 373 |
+
.. code-block:: text
|
| 374 |
+
|
| 375 |
+
Year: 1901, Season: SPRING_MOVES
|
| 376 |
+
You are playing as FRANCE.
|
| 377 |
+
You currently control 3 supply centers: PAR, MAR, BRE.
|
| 378 |
+
Your units are: A PAR, A MAR, F BRE.
|
| 379 |
+
|
| 380 |
+
Please provide orders for your units. Here are your possible actions:
|
| 381 |
+
A PAR - BUR
|
| 382 |
+
A PAR - GAS
|
| 383 |
+
A PAR - PIC
|
| 384 |
+
A PAR H
|
| 385 |
+
...
|
| 386 |
+
|
| 387 |
+
Submit your orders, one per line, in the format like: "A MUN - BER" or "F NTH C A LON - BEL"
|
| 388 |
+
|
| 389 |
+
Running Diplomacy Games
|
| 390 |
+
----------------------
|
| 391 |
+
|
| 392 |
+
To run Diplomacy games with LLM agents, you can use the ``run_batched_matches`` function with the ``DiplomacyEnv`` and ``DiplomacyAgent`` classes:
|
| 393 |
+
|
| 394 |
+
.. code-block:: python
|
| 395 |
+
|
| 396 |
+
from mllm.environments.diplomacy.diplomacy_env import DiplomacyEnv
|
| 397 |
+
from mllm.environments.diplomacy.diplomacy_agent import DiplomacyAgent
|
| 398 |
+
from mllm.run_matches import run_batched_matches
|
| 399 |
+
|
| 400 |
+
# Create environment and agent handlers
|
| 401 |
+
env = DiplomacyEnv(max_turns=30)
|
| 402 |
+
|
| 403 |
+
agent_handlers = {
|
| 404 |
+
"AUSTRIA": DiplomacyAgent(power_name="AUSTRIA"),
|
| 405 |
+
"ENGLAND": DiplomacyAgent(power_name="ENGLAND"),
|
| 406 |
+
"FRANCE": DiplomacyAgent(power_name="FRANCE"),
|
| 407 |
+
"GERMANY": DiplomacyAgent(power_name="GERMANY"),
|
| 408 |
+
"ITALY": DiplomacyAgent(power_name="ITALY"),
|
| 409 |
+
"RUSSIA": DiplomacyAgent(power_name="RUSSIA"),
|
| 410 |
+
"TURKEY": DiplomacyAgent(power_name="TURKEY")
|
| 411 |
+
}
|
| 412 |
+
|
| 413 |
+
# Define policy mapping (mapping from policy IDs to actual policy functions)
|
| 414 |
+
policy_mapping = {
|
| 415 |
+
"llm_policy": my_llm_policy_function
|
| 416 |
+
}
|
| 417 |
+
|
| 418 |
+
# Run the game
|
| 419 |
+
game_results = run_batched_matches(
|
| 420 |
+
envs=[env],
|
| 421 |
+
agent_handlers_per_env=[agent_handlers],
|
| 422 |
+
policy_mapping=policy_mapping,
|
| 423 |
+
max_parallel_matches=1
|
| 424 |
+
)
|
| 425 |
+
|
| 426 |
+
# Process results
|
| 427 |
+
for result in game_results:
|
| 428 |
+
print(f"Game finished. Winner: {result['winner']}")
|
| 429 |
+
print(f"Supply centers: {result['supply_centers']}")
|
| 430 |
+
|
| 431 |
+
This setup allows you to run Diplomacy games with LLM agents using the Multi-Agent Negotiation Environment standard.
|
| 432 |
+
|
| 433 |
+
Limitations and Considerations
|
| 434 |
+
-----------------------------
|
| 435 |
+
|
| 436 |
+
1. **Performance**: Processing observations and actions for seven powers using LLMs can be computationally intensive.
|
| 437 |
+
|
| 438 |
+
2. **Action Parsing**: Extracting valid actions from LLM outputs may require sophisticated parsing and error handling.
|
| 439 |
+
|
| 440 |
+
3. **Game Complexity**: Diplomacy is a complex game with many rules and edge cases, which may be challenging for LLMs to fully grasp.
|
| 441 |
+
|
| 442 |
+
4. **Turn Duration**: Real Diplomacy games include negotiation phases of variable duration, which are not fully captured in this implementation.
|
| 443 |
+
|
| 444 |
+
5. **Text Formatting**: The quality of LLM interactions depends heavily on the formatting and clarity of text prompts.
|
| 445 |
+
|
| 446 |
+
Advanced Usage
|
| 447 |
+
------------
|
| 448 |
+
|
| 449 |
+
For advanced usage, you can customize:
|
| 450 |
+
|
| 451 |
+
1. **System Prompts**: Modify agent behavior by providing custom system prompts.
|
| 452 |
+
|
| 453 |
+
2. **Observation Processing**: Extend the observation processing to include additional information.
|
| 454 |
+
|
| 455 |
+
3. **Action Parsing**: Implement more sophisticated action parsing for complex orders.
|
| 456 |
+
|
| 457 |
+
4. **Visualization**: Add custom visualization methods to the environment's render function.
|
| 458 |
+
|
| 459 |
+
5. **Logging**: Extend the logging capabilities to capture additional information about the game state.
|
src_code_for_reproducibility/docs/source/environments/dond.rst
ADDED
|
@@ -0,0 +1,410 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
=================
|
| 2 |
+
Deal or No Deal
|
| 3 |
+
=================
|
| 4 |
+
|
| 5 |
+
The Deal or No Deal (DoND) environment provides a multi-agent negotiation interface where players trade
|
| 6 |
+
items with different values. This document describes the API for interacting with the DoND environment
|
| 7 |
+
and its associated agent handler.
|
| 8 |
+
|
| 9 |
+
Overview
|
| 10 |
+
--------
|
| 11 |
+
|
| 12 |
+
Deal or No Deal is a negotiation game where two agents must agree on how to divide a set of items,
|
| 13 |
+
each of which has different values to each agent. The agents engage in a back-and-forth dialogue to
|
| 14 |
+
determine an allocation of the items, with each trying to maximize their own total value.
|
| 15 |
+
|
| 16 |
+
Our implementation follows the Multi-Agent Negotiation Environment standard, allowing it to be used
|
| 17 |
+
with LLM agents through a text-based interface.
|
| 18 |
+
|
| 19 |
+
Game Rules
|
| 20 |
+
----------
|
| 21 |
+
|
| 22 |
+
### Basic Structure
|
| 23 |
+
|
| 24 |
+
The core mechanics of Deal or No Deal are:
|
| 25 |
+
|
| 26 |
+
1. Two agents negotiate over a set of items (e.g., books, balls, hats)
|
| 27 |
+
2. Each item has:
|
| 28 |
+
- A specific quantity (how many of each item is available)
|
| 29 |
+
- A value for each agent (which may differ between agents)
|
| 30 |
+
3. Agents take turns sending messages to negotiate how to split the items
|
| 31 |
+
4. Once an agreement is reached, agents finalize the deal
|
| 32 |
+
5. Points are awarded based on the value of items each agent receives
|
| 33 |
+
|
| 34 |
+
### Detailed Gameplay
|
| 35 |
+
|
| 36 |
+
#### Setup Phase
|
| 37 |
+
|
| 38 |
+
The game begins with:
|
| 39 |
+
- A set of items (e.g., "book", "hat", "ball")
|
| 40 |
+
- Each item has a quantity (e.g., 6 books, 2 hats, 4 balls)
|
| 41 |
+
- Each agent has private values for each item (e.g., books might be worth 5 points to one agent but only 2 points to the other)
|
| 42 |
+
- Agents are assigned roles (starting negotiator and responding negotiator)
|
| 43 |
+
|
| 44 |
+
#### Negotiation Phase
|
| 45 |
+
|
| 46 |
+
1. Agents take turns sending free-form text messages to each other
|
| 47 |
+
2. Messages can include offers, counter-offers, questions, or strategic communication
|
| 48 |
+
3. There is a maximum number of messages permitted (preventing endless negotiations)
|
| 49 |
+
4. Either agent can propose to finalize an agreement at any time
|
| 50 |
+
|
| 51 |
+
For example:
|
| 52 |
+
- Agent 1: "I propose I get all the books and you get all the hats and balls."
|
| 53 |
+
- Agent 2: "That doesn't work for me. How about you get 3 books and I get 3 books, all the hats, and all the balls?"
|
| 54 |
+
- Agent 1: "Let me counter-offer: I get 4 books and 2 balls, you get 2 books, all hats, and 2 balls."
|
| 55 |
+
|
| 56 |
+
#### Finalization Phase
|
| 57 |
+
|
| 58 |
+
1. When an agent wants to finalize a deal, they must specify the exact allocation:
|
| 59 |
+
- How many of each item they receive
|
| 60 |
+
- How many of each item the other agent receives
|
| 61 |
+
2. The other agent must then either agree (by submitting the same allocation) or reject the finalization
|
| 62 |
+
3. If both agents submit matching finalizations, the deal is executed
|
| 63 |
+
4. If finalizations don't match, no agreement is reached, and both agents receive 0 points
|
| 64 |
+
|
| 65 |
+
#### Scoring
|
| 66 |
+
|
| 67 |
+
1. Each agent's score is calculated based on the value of items they receive
|
| 68 |
+
2. The formula is: Sum(quantity_of_item_i × value_of_item_i_to_agent)
|
| 69 |
+
3. If no agreement is reached, both agents receive 0 points
|
| 70 |
+
|
| 71 |
+
### Example Game
|
| 72 |
+
|
| 73 |
+
Let's walk through a simple example:
|
| 74 |
+
|
| 75 |
+
**Setup:**
|
| 76 |
+
- Items: Books (4), Hats (2), Balls (6)
|
| 77 |
+
- Agent 1 values: Books=5, Hats=1, Balls=2
|
| 78 |
+
- Agent 2 values: Books=3, Hats=6, Balls=1
|
| 79 |
+
|
| 80 |
+
**Negotiation (simplified):**
|
| 81 |
+
1. Agent 1: "I would like all the books and balls. You can have the hats."
|
| 82 |
+
2. Agent 2: "That doesn't work for me. Books are valuable. I propose I get all the hats and 2 books, you get 2 books and all the balls."
|
| 83 |
+
3. Agent 1: "How about I get 3 books and all the balls, and you get 1 book and all the hats?"
|
| 84 |
+
4. Agent 2: "I accept your proposal."
|
| 85 |
+
|
| 86 |
+
**Finalization:**
|
| 87 |
+
- Agent 1 submits: Agent 1 gets (Books: 3, Hats: 0, Balls: 6), Agent 2 gets (Books: 1, Hats: 2, Balls: 0)
|
| 88 |
+
- Agent 2 submits the same allocation, confirming agreement
|
| 89 |
+
|
| 90 |
+
**Scoring:**
|
| 91 |
+
- Agent 1 score: (3 books × 5) + (0 hats × 1) + (6 balls × 2) = 15 + 0 + 12 = 27 points
|
| 92 |
+
- Agent 2 score: (1 book × 3) + (2 hats × 6) + (0 balls × 1) = 3 + 12 + 0 = 15 points
|
| 93 |
+
|
| 94 |
+
### Game Variations
|
| 95 |
+
|
| 96 |
+
The DoND environment supports several variations through configuration parameters:
|
| 97 |
+
|
| 98 |
+
#### Different Value Distributions
|
| 99 |
+
|
| 100 |
+
The environment offers multiple ways to assign values to items:
|
| 101 |
+
|
| 102 |
+
1. **Standard Random Setup (dond_random_setup)**:
|
| 103 |
+
- Items have even-numbered quantities
|
| 104 |
+
- Each agent receives distinct random values for each item
|
| 105 |
+
- Values are drawn from a uniform distribution
|
| 106 |
+
|
| 107 |
+
2. **Independent Random Values (independent_random_vals)**:
|
| 108 |
+
- Item quantities can be any number in the specified range
|
| 109 |
+
- Values for each agent are drawn independently
|
| 110 |
+
- Creates more varied negotiation scenarios
|
| 111 |
+
|
| 112 |
+
3. **Bicameral Value Distribution (bicameral_vals_assignator)**:
|
| 113 |
+
- Creates a "high value" and "low value" distribution for each item
|
| 114 |
+
- Each agent values approximately half the items highly and half lowly
|
| 115 |
+
- Values are drawn from normal distributions with different means
|
| 116 |
+
- Creates scenarios with clear trade opportunities
|
| 117 |
+
|
| 118 |
+
#### Visibility Options
|
| 119 |
+
|
| 120 |
+
1. **Finalization Visibility**:
|
| 121 |
+
- When enabled, both agents can see each other's finalization proposals
|
| 122 |
+
- When disabled, finalization proposals remain private until both are submitted
|
| 123 |
+
|
| 124 |
+
2. **Other Values Visibility**:
|
| 125 |
+
- When enabled, agents can see each other's value functions
|
| 126 |
+
- When disabled, agents only know their own values
|
| 127 |
+
- Creates information asymmetry and richer negotiation dynamics
|
| 128 |
+
|
| 129 |
+
#### Game Modes
|
| 130 |
+
|
| 131 |
+
1. **Cooperative Mode ("coop")**:
|
| 132 |
+
- Agents are encouraged to find mutually beneficial solutions
|
| 133 |
+
- Success is measured by the sum of both agents' scores
|
| 134 |
+
|
| 135 |
+
2. **Competitive Mode ("comp")**:
|
| 136 |
+
- Agents aim to maximize their individual scores
|
| 137 |
+
- Creates more adversarial negotiations
|
| 138 |
+
|
| 139 |
+
#### Round Structure
|
| 140 |
+
|
| 141 |
+
1. **Single Round**:
|
| 142 |
+
- One negotiation session between the same agents
|
| 143 |
+
- Simple evaluation of negotiation skills
|
| 144 |
+
|
| 145 |
+
2. **Multiple Rounds**:
|
| 146 |
+
- Agents negotiate multiple times with different item setups
|
| 147 |
+
- Allows for learning and adaptation over time
|
| 148 |
+
- Roles can be swapped between rounds
|
| 149 |
+
|
| 150 |
+
DondEnv
|
| 151 |
+
------------
|
| 152 |
+
|
| 153 |
+
The ``DondEnv`` class provides an interface to the Deal or No Deal environment that follows the Multi-Agent
|
| 154 |
+
Negotiation Environment standard.
|
| 155 |
+
|
| 156 |
+
.. code-block:: python
|
| 157 |
+
|
| 158 |
+
class DondEnv:
|
| 159 |
+
"""
|
| 160 |
+
Multi-Agent Negotiation Environment for Deal or No Deal.
|
| 161 |
+
"""
|
| 162 |
+
def __init__(
|
| 163 |
+
self,
|
| 164 |
+
agents,
|
| 165 |
+
mode="coop",
|
| 166 |
+
max_messages=None,
|
| 167 |
+
min_messages=None,
|
| 168 |
+
max_chars_per_message=None,
|
| 169 |
+
rounds_per_game=1,
|
| 170 |
+
random_setup_func=None,
|
| 171 |
+
random_setup_kwargs=None,
|
| 172 |
+
role_assignator_func=None,
|
| 173 |
+
role_assignator_func_kwargs=None,
|
| 174 |
+
finalization_visibility=False,
|
| 175 |
+
other_values_visibility=False,
|
| 176 |
+
random_seed=None
|
| 177 |
+
):
|
| 178 |
+
"""Initialize the Deal or No Deal environment.
|
| 179 |
+
|
| 180 |
+
Args:
|
| 181 |
+
agents: List of agent IDs participating in the game
|
| 182 |
+
mode: Game mode ("coop" or "comp")
|
| 183 |
+
max_messages: Maximum number of messages per agent per round
|
| 184 |
+
min_messages: Minimum number of messages per agent per round
|
| 185 |
+
max_chars_per_message: Maximum characters per message
|
| 186 |
+
rounds_per_game: Number of negotiation rounds to play
|
| 187 |
+
random_setup_func: Function to generate item quantities and values
|
| 188 |
+
random_setup_kwargs: Arguments for the random setup function
|
| 189 |
+
role_assignator_func: Function to assign roles to agents
|
| 190 |
+
role_assignator_func_kwargs: Arguments for the role assignator
|
| 191 |
+
finalization_visibility: Whether agents can see each other's finalizations
|
| 192 |
+
other_values_visibility: Whether agents can see each other's values
|
| 193 |
+
random_seed: Seed for reproducibility
|
| 194 |
+
"""
|
| 195 |
+
# ...
|
| 196 |
+
|
| 197 |
+
def reset(self):
|
| 198 |
+
"""Reset the environment to an initial state and return the initial observation.
|
| 199 |
+
|
| 200 |
+
Returns:
|
| 201 |
+
observation (dict): A dictionary where keys are agent identifiers and values are observations.
|
| 202 |
+
"""
|
| 203 |
+
# ...
|
| 204 |
+
|
| 205 |
+
def step(self, actions):
|
| 206 |
+
"""Take a step in the environment using the provided actions.
|
| 207 |
+
|
| 208 |
+
Args:
|
| 209 |
+
actions (dict): A dictionary where keys are agent identifiers and values are actions.
|
| 210 |
+
Actions can be messages or finalization proposals.
|
| 211 |
+
|
| 212 |
+
Returns:
|
| 213 |
+
observations (dict): A dictionary where keys are agent identifiers and values are observations.
|
| 214 |
+
done (bool): Whether the episode has ended.
|
| 215 |
+
info (dict): Additional information about the environment.
|
| 216 |
+
"""
|
| 217 |
+
# ...
|
| 218 |
+
|
| 219 |
+
def get_state(self):
|
| 220 |
+
"""Retrieve the current state of the game.
|
| 221 |
+
|
| 222 |
+
Returns:
|
| 223 |
+
state (dict): The current state of the game, including items, quantities, values, etc.
|
| 224 |
+
"""
|
| 225 |
+
# ...
|
| 226 |
+
|
| 227 |
+
Key Implementation Details
|
| 228 |
+
~~~~~~~~~~~~~~~~~~~~~~~~~
|
| 229 |
+
|
| 230 |
+
The ``DondEnv`` class implements several key features:
|
| 231 |
+
|
| 232 |
+
1. **Multi-Agent Support**: The environment tracks two agents and manages their alternating messages.
|
| 233 |
+
|
| 234 |
+
2. **Turn-Based Dialogue**: The environment enforces turn structure and limits on message count.
|
| 235 |
+
|
| 236 |
+
3. **Finalization Processing**: The environment validates and processes finalization proposals.
|
| 237 |
+
|
| 238 |
+
4. **Random Setup**: The environment supports multiple methods of generating negotiation scenarios.
|
| 239 |
+
|
| 240 |
+
5. **Round Management**: The environment can handle multiple rounds with different setups.
|
| 241 |
+
|
| 242 |
+
Observation Structure
|
| 243 |
+
~~~~~~~~~~~~~~~~~~~~
|
| 244 |
+
|
| 245 |
+
Each agent receives an observation (state) dictionary with rich information about the game:
|
| 246 |
+
|
| 247 |
+
.. code-block:: python
|
| 248 |
+
|
| 249 |
+
{
|
| 250 |
+
"mode": str, # Game mode ("coop" or "comp")
|
| 251 |
+
"role_values": dict, # Value mappings for each role
|
| 252 |
+
"role_props": dict, # Properties for each role
|
| 253 |
+
"agent_to_role": dict, # Mapping from agent IDs to roles
|
| 254 |
+
"is_new_round": bool, # Whether this is the start of a new round
|
| 255 |
+
"is_new_game": bool, # Whether this is the start of a new game
|
| 256 |
+
"game_over": bool, # Whether the game is over
|
| 257 |
+
"items": list, # List of item names
|
| 258 |
+
"quantities": dict, # Quantities of each item
|
| 259 |
+
"has_finalized": bool, # Whether finalization has been proposed
|
| 260 |
+
"last_message": dict, # The last message sent
|
| 261 |
+
"messages_remaining": dict, # Number of messages each agent can still send
|
| 262 |
+
# And various history tracking fields
|
| 263 |
+
}
|
| 264 |
+
|
| 265 |
+
Action Structure
|
| 266 |
+
~~~~~~~~~~~~~~~
|
| 267 |
+
|
| 268 |
+
Actions can be:
|
| 269 |
+
|
| 270 |
+
1. **Text Messages**: Free-form text for negotiation.
|
| 271 |
+
2. **Finalization Proposals**: Structured data specifying the exact allocation of items.
|
| 272 |
+
|
| 273 |
+
Example finalization format:
|
| 274 |
+
|
| 275 |
+
.. code-block:: python
|
| 276 |
+
|
| 277 |
+
{
|
| 278 |
+
"type": "finalize",
|
| 279 |
+
"allocation": {
|
| 280 |
+
"agent1": {"book": 3, "hat": 0, "ball": 6},
|
| 281 |
+
"agent2": {"book": 1, "hat": 2, "ball": 0}
|
| 282 |
+
}
|
| 283 |
+
}
|
| 284 |
+
|
| 285 |
+
Value Setup Functions
|
| 286 |
+
--------------------
|
| 287 |
+
|
| 288 |
+
The DoND environment provides several functions for setting up item values:
|
| 289 |
+
|
| 290 |
+
.. code-block:: python
|
| 291 |
+
|
| 292 |
+
def dond_random_setup(items, min_quant, max_quant, min_val, max_val, random_seed=None):
|
| 293 |
+
"""
|
| 294 |
+
Generates items, even-numbered quantities and distinct random values for each category for both agents.
|
| 295 |
+
|
| 296 |
+
Args:
|
| 297 |
+
items (list): List of items.
|
| 298 |
+
min_quant (int): Minimum quantity per item.
|
| 299 |
+
max_quant (int): Maximum quantity per item.
|
| 300 |
+
min_val (int): Minimum value per item.
|
| 301 |
+
max_val (int): Maximum value per item.
|
| 302 |
+
random_seed (int, optional): Seed for random generation.
|
| 303 |
+
|
| 304 |
+
Returns:
|
| 305 |
+
tuple: (items, quantities, (val_starting_negotiator, val_responding_negotiator))
|
| 306 |
+
"""
|
| 307 |
+
# ...
|
| 308 |
+
|
| 309 |
+
def independent_random_vals(items, min_quant, max_quant, min_val, max_val, random_seed=None):
|
| 310 |
+
"""
|
| 311 |
+
Generates random quantities and independent random values for both agents.
|
| 312 |
+
|
| 313 |
+
Args:
|
| 314 |
+
Similar to dond_random_setup
|
| 315 |
+
|
| 316 |
+
Returns:
|
| 317 |
+
tuple: (items, quantities, (val_starting_negotiator, val_responding_negotiator))
|
| 318 |
+
"""
|
| 319 |
+
# ...
|
| 320 |
+
|
| 321 |
+
def bicameral_vals_assignator(items, min_quant, max_quant, low_val_mean, low_val_std, high_val_mean, high_val_std, random_seed=None):
|
| 322 |
+
"""
|
| 323 |
+
Generates values with a bicameral distribution - each agent values half the items highly.
|
| 324 |
+
|
| 325 |
+
Args:
|
| 326 |
+
items (list): List of items.
|
| 327 |
+
min_quant, max_quant: Range for quantities
|
| 328 |
+
low_val_mean, low_val_std: Mean and standard deviation for the "low value" distribution
|
| 329 |
+
high_val_mean, high_val_std: Mean and standard deviation for the "high value" distribution
|
| 330 |
+
random_seed: Seed for reproducibility
|
| 331 |
+
|
| 332 |
+
Returns:
|
| 333 |
+
tuple: (items, quantities, (val_starting_negotiator, val_responding_negotiator))
|
| 334 |
+
"""
|
| 335 |
+
# ...
|
| 336 |
+
|
| 337 |
+
Running DoND Games
|
| 338 |
+
----------------------
|
| 339 |
+
|
| 340 |
+
To run Deal or No Deal games with LLM agents, you can use the following structure:
|
| 341 |
+
|
| 342 |
+
.. code-block:: python
|
| 343 |
+
|
| 344 |
+
from mllm.environments.dond.dond_game import DondEnv
|
| 345 |
+
from mllm.environments.dond.dond_agent import DondAgent
|
| 346 |
+
from src.run_matches import run_batched_matches
|
| 347 |
+
|
| 348 |
+
# Create environment
|
| 349 |
+
env = DondEnv(
|
| 350 |
+
agents=["agent1", "agent2"],
|
| 351 |
+
mode="coop",
|
| 352 |
+
max_messages=10,
|
| 353 |
+
rounds_per_game=1,
|
| 354 |
+
random_setup_func="dond_random_setup",
|
| 355 |
+
random_setup_kwargs={
|
| 356 |
+
"items": ["book", "hat", "ball"],
|
| 357 |
+
"min_quant": 2,
|
| 358 |
+
"max_quant": 8,
|
| 359 |
+
"min_val": 1,
|
| 360 |
+
"max_val": 10
|
| 361 |
+
},
|
| 362 |
+
finalization_visibility=False
|
| 363 |
+
)
|
| 364 |
+
|
| 365 |
+
# Create agent handlers (implementation details would vary)
|
| 366 |
+
agent_handlers = {
|
| 367 |
+
"agent1": DondAgent(agent_id="agent1"),
|
| 368 |
+
"agent2": DondAgent(agent_id="agent2")
|
| 369 |
+
}
|
| 370 |
+
|
| 371 |
+
# Define policy mapping
|
| 372 |
+
policy_mapping = {
|
| 373 |
+
"llm_policy": my_llm_policy_function
|
| 374 |
+
}
|
| 375 |
+
|
| 376 |
+
# Run the game
|
| 377 |
+
game_results = run_batched_matches(
|
| 378 |
+
envs=[env],
|
| 379 |
+
agent_handlers_per_env=[agent_handlers],
|
| 380 |
+
policy_mapping=policy_mapping,
|
| 381 |
+
max_parallel_matches=1
|
| 382 |
+
)
|
| 383 |
+
|
| 384 |
+
Limitations and Considerations
|
| 385 |
+
-----------------------------
|
| 386 |
+
|
| 387 |
+
1. **Negotiation Complexity**: The open-ended nature of negotiations can be challenging for some LLM agents.
|
| 388 |
+
|
| 389 |
+
2. **Parsing Challenges**: Extracting structured finalization proposals from free-form text requires robust parsing.
|
| 390 |
+
|
| 391 |
+
3. **Optimization Opportunities**: Different agents may employ different negotiation strategies to optimize outcomes.
|
| 392 |
+
|
| 393 |
+
4. **Fairness Evaluation**: The environment allows research into questions of fair division and Pareto optimality.
|
| 394 |
+
|
| 395 |
+
5. **Strategic Deception**: Agents might strategically misrepresent their true values, adding complexity to negotiations.
|
| 396 |
+
|
| 397 |
+
Advanced Usage
|
| 398 |
+
------------
|
| 399 |
+
|
| 400 |
+
For advanced usage, you can:
|
| 401 |
+
|
| 402 |
+
1. **Custom Value Functions**: Create more complex distributions of item values for specific research questions.
|
| 403 |
+
|
| 404 |
+
2. **Novel Negotiation Scenarios**: Design item sets and values to test specific negotiation skills.
|
| 405 |
+
|
| 406 |
+
3. **Curriculum Learning**: Create progressively more difficult negotiation scenarios.
|
| 407 |
+
|
| 408 |
+
4. **Communication Analysis**: Analyze the language and strategies used in successful negotiations.
|
| 409 |
+
|
| 410 |
+
5. **Multi-Round Dynamics**: Study how agents adapt their strategies over multiple rounds.
|
src_code_for_reproducibility/docs/source/environments/ipd.rst
ADDED
|
@@ -0,0 +1,411 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
=================
|
| 2 |
+
Iterated Prisoner's Dilemma
|
| 3 |
+
=================
|
| 4 |
+
|
| 5 |
+
The Iterated Prisoner's Dilemma environment provides a classic game theory setting for studying cooperation
|
| 6 |
+
and competition between agents. This document describes the API for interacting with the IPD environment
|
| 7 |
+
and its associated agent handler.
|
| 8 |
+
|
| 9 |
+
Overview
|
| 10 |
+
--------
|
| 11 |
+
|
| 12 |
+
The Prisoner's Dilemma is a fundamental problem in game theory that demonstrates why two rational individuals might not
|
| 13 |
+
cooperate, even when it appears in their best interest to do so. In the iterated version, the same two players
|
| 14 |
+
repeatedly face the same dilemma, allowing for the development of trust or retaliation based on previous interactions.
|
| 15 |
+
|
| 16 |
+
Our implementation follows the Multi-Agent Negotiation Environment standard, allowing it to be used with
|
| 17 |
+
LLM agents through a text-based interface.
|
| 18 |
+
|
| 19 |
+
Game Rules
|
| 20 |
+
----------
|
| 21 |
+
|
| 22 |
+
### Basic Premise
|
| 23 |
+
|
| 24 |
+
The scenario behind the Prisoner's Dilemma is as follows:
|
| 25 |
+
|
| 26 |
+
Two criminals are arrested and imprisoned. Each prisoner is in solitary confinement with no means of communicating with
|
| 27 |
+
the other. The prosecutors lack sufficient evidence to convict the pair on the principal charge, but they have enough
|
| 28 |
+
to convict both on a lesser charge. Simultaneously, the prosecutors offer each prisoner a bargain:
|
| 29 |
+
|
| 30 |
+
- If both prisoners betray each other, each serves 2 years in prison (the "punishment" payoff)
|
| 31 |
+
- If one betrays the other while the other remains silent, the betrayer goes free (the "temptation" payoff) while the
|
| 32 |
+
silent accomplice serves 3 years (the "sucker" payoff)
|
| 33 |
+
- If both remain silent, each serves only 1 year in prison (the "reward" payoff)
|
| 34 |
+
|
| 35 |
+
### Game Mechanics
|
| 36 |
+
|
| 37 |
+
In our implementation, the choices are simplified to:
|
| 38 |
+
- **C**: Cooperate (remain silent)
|
| 39 |
+
- **D**: Defect (betray the other prisoner)
|
| 40 |
+
|
| 41 |
+
Each round, both players simultaneously choose either C or D, and receive points based on the combination of their choices:
|
| 42 |
+
|
| 43 |
+
- Both choose C: Both receive the "reward" payoff (3 points by default)
|
| 44 |
+
- Both choose D: Both receive the "punishment" payoff (1 point by default)
|
| 45 |
+
- One chooses C, one chooses D: The defector receives the "temptation" payoff (5 points by default), while the cooperator
|
| 46 |
+
receives the "sucker" payoff (0 points by default)
|
| 47 |
+
|
| 48 |
+
### Example: Single Round
|
| 49 |
+
|
| 50 |
+
Let's see how a single round plays out:
|
| 51 |
+
|
| 52 |
+
1. Alice and Bob simultaneously make their choices
|
| 53 |
+
2. If Alice chooses C and Bob chooses C:
|
| 54 |
+
- Alice receives 3 points
|
| 55 |
+
- Bob receives 3 points
|
| 56 |
+
3. If Alice chooses C and Bob chooses D:
|
| 57 |
+
- Alice receives 0 points
|
| 58 |
+
- Bob receives 5 points
|
| 59 |
+
4. If Alice chooses D and Bob chooses C:
|
| 60 |
+
- Alice receives 5 points
|
| 61 |
+
- Bob receives 0 points
|
| 62 |
+
5. If Alice chooses D and Bob chooses D:
|
| 63 |
+
- Alice receives 1 point
|
| 64 |
+
- Bob receives 1 point
|
| 65 |
+
|
| 66 |
+
### Iterated Game Structure
|
| 67 |
+
|
| 68 |
+
The iterated version repeats this basic game for a fixed number of rounds. The key features are:
|
| 69 |
+
|
| 70 |
+
1. Players know the total number of rounds in advance
|
| 71 |
+
2. After each round, players learn what choice the other player made
|
| 72 |
+
3. Players maintain a cumulative score across all rounds
|
| 73 |
+
4. Players can adjust their strategy based on the history of previous interactions
|
| 74 |
+
|
| 75 |
+
### Game Variations
|
| 76 |
+
|
| 77 |
+
The IPD environment supports several variations through configuration parameters:
|
| 78 |
+
|
| 79 |
+
#### Different Payoff Matrices
|
| 80 |
+
|
| 81 |
+
The standard payoff values can be modified to create different incentive structures:
|
| 82 |
+
- **Traditional PD**: reward=3, punishment=1, temptation=5, sucker=0
|
| 83 |
+
- **Weak Temptation**: reward=3, punishment=1, temptation=4, sucker=0 (reduces the incentive to defect)
|
| 84 |
+
- **Harsh Punishment**: reward=3, punishment=0, temptation=5, sucker=0 (increases the cost of mutual defection)
|
| 85 |
+
- **Generous**: reward=4, punishment=2, temptation=5, sucker=1 (cushions the blow of being betrayed)
|
| 86 |
+
|
| 87 |
+
#### Game Length Variations
|
| 88 |
+
|
| 89 |
+
The number of rounds can significantly impact strategy:
|
| 90 |
+
- **Short Games** (5-10 rounds): Incentivizes more defection, especially near the end
|
| 91 |
+
- **Medium Games** (20-50 rounds): Allows for the development of tit-for-tat and forgiveness strategies
|
| 92 |
+
- **Long Games** (100+ rounds): Favors steady cooperation with occasional "probing" defections
|
| 93 |
+
|
| 94 |
+
### Common Strategies
|
| 95 |
+
|
| 96 |
+
While not enforced by the environment, several well-known strategies can emerge:
|
| 97 |
+
- **Always Cooperate**: Always choose C
|
| 98 |
+
- **Always Defect**: Always choose D
|
| 99 |
+
- **Tit for Tat**: Start with C, then copy what the opponent did in the previous round
|
| 100 |
+
- **Forgiving Tit for Tat**: Like Tit for Tat, but occasionally cooperate even after being defected against
|
| 101 |
+
- **Grudger**: Cooperate until the opponent defects once, then always defect
|
| 102 |
+
- **Random**: Choose randomly between C and D
|
| 103 |
+
|
| 104 |
+
IPDEnv
|
| 105 |
+
------
|
| 106 |
+
|
| 107 |
+
The ``IPDEnv`` class provides an interface to the Iterated Prisoner's Dilemma environment that follows the
|
| 108 |
+
Multi-Agent Negotiation Environment standard.
|
| 109 |
+
|
| 110 |
+
.. code-block:: python
|
| 111 |
+
|
| 112 |
+
class IPDEnv:
|
| 113 |
+
"""
|
| 114 |
+
Iterated Prisoner's Dilemma environment following the MarlEnvironment standard.
|
| 115 |
+
|
| 116 |
+
In each round of the game, two agents simultaneously choose to either cooperate (C) or defect (D).
|
| 117 |
+
The payoffs are as follows:
|
| 118 |
+
- If both cooperate: Both receive the "reward" (usually 3 points)
|
| 119 |
+
- If both defect: Both receive the "punishment" (usually 1 point)
|
| 120 |
+
- If one cooperates and one defects: The defector receives the "temptation" (usually 5 points)
|
| 121 |
+
and the cooperator receives the "sucker" payoff (usually 0 points)
|
| 122 |
+
|
| 123 |
+
The game is played for a specified number of rounds.
|
| 124 |
+
"""
|
| 125 |
+
|
| 126 |
+
def __init__(
|
| 127 |
+
self,
|
| 128 |
+
rounds_per_game: int = 10,
|
| 129 |
+
reward: float = 3.0, # Both cooperate
|
| 130 |
+
punishment: float = 1.0, # Both defect
|
| 131 |
+
temptation: float = 5.0, # Defector's reward when other cooperates
|
| 132 |
+
sucker: float = 0.0, # Cooperator's reward when other defects
|
| 133 |
+
random_seed: Optional[int] = None,
|
| 134 |
+
):
|
| 135 |
+
"""
|
| 136 |
+
Initialize the Iterated Prisoner's Dilemma environment.
|
| 137 |
+
|
| 138 |
+
Args:
|
| 139 |
+
rounds_per_game: Number of rounds to play
|
| 140 |
+
reward: Payoff when both agents cooperate
|
| 141 |
+
punishment: Payoff when both agents defect
|
| 142 |
+
temptation: Payoff for defecting when other agent cooperates
|
| 143 |
+
sucker: Payoff for cooperating when other agent defects
|
| 144 |
+
seed: Random seed for reproducibility
|
| 145 |
+
"""
|
| 146 |
+
# ...
|
| 147 |
+
|
| 148 |
+
def reset(self) -> Dict[str, Dict[str, Any]]:
|
| 149 |
+
"""
|
| 150 |
+
Reset the environment to an initial state and return the initial observation.
|
| 151 |
+
|
| 152 |
+
Returns:
|
| 153 |
+
observation (dict): A dictionary where keys are agent identifiers and values are observations.
|
| 154 |
+
"""
|
| 155 |
+
# ...
|
| 156 |
+
|
| 157 |
+
def step(self, actions: Dict[str, str]) -> Tuple[Dict[str, Dict[str, Any]], bool, Dict[str, Any]]:
|
| 158 |
+
"""
|
| 159 |
+
Take a step in the environment using the provided actions.
|
| 160 |
+
|
| 161 |
+
Args:
|
| 162 |
+
actions (dict): A dictionary where keys are agent identifiers and values are actions ('C' or 'D').
|
| 163 |
+
|
| 164 |
+
Returns:
|
| 165 |
+
observations (dict): A dictionary where keys are agent identifiers and values are observations.
|
| 166 |
+
done (bool): Whether the episode has ended.
|
| 167 |
+
info (dict): Additional information about the environment.
|
| 168 |
+
"""
|
| 169 |
+
# ...
|
| 170 |
+
|
| 171 |
+
Key Implementation Details
|
| 172 |
+
~~~~~~~~~~~~~~~~~~~~~~~~~
|
| 173 |
+
|
| 174 |
+
The ``IPDEnv`` class implements several key features:
|
| 175 |
+
|
| 176 |
+
1. **Two-Agent Support**: The environment tracks two agents ("alice" and "bob") and manages their interactions.
|
| 177 |
+
|
| 178 |
+
2. **Round-Based Play**: The environment enforces turn structure and tracks game history.
|
| 179 |
+
|
| 180 |
+
3. **Payoff Matrix**: The environment calculates rewards based on the standard prisoner's dilemma payoff matrix.
|
| 181 |
+
|
| 182 |
+
4. **Observation Generation**: The environment generates detailed observations for each agent, including action history and rewards.
|
| 183 |
+
|
| 184 |
+
5. **Game Termination**: The environment tracks game termination after the specified number of rounds.
|
| 185 |
+
|
| 186 |
+
Observation Structure
|
| 187 |
+
~~~~~~~~~~~~~~~~~~~~
|
| 188 |
+
|
| 189 |
+
Each agent receives an observation dictionary with the following structure:
|
| 190 |
+
|
| 191 |
+
.. code-block:: python
|
| 192 |
+
|
| 193 |
+
{
|
| 194 |
+
"current_round": int, # Current round number (0-indexed)
|
| 195 |
+
"rounds_per_game": int, # Total number of rounds in the game
|
| 196 |
+
"history": List[Dict], # Complete game history so far
|
| 197 |
+
"last_round_actions": Dict[str, str], # Actions from the previous round (if any)
|
| 198 |
+
"last_round_reward": float, # Reward received in the previous round (if any)
|
| 199 |
+
"total_reward": float, # Cumulative reward so far
|
| 200 |
+
"payoff_matrix": Dict[str, float], # The game's payoff matrix values
|
| 201 |
+
}
|
| 202 |
+
|
| 203 |
+
Action Structure
|
| 204 |
+
~~~~~~~~~~~~~~~
|
| 205 |
+
|
| 206 |
+
Actions are simple strings:
|
| 207 |
+
|
| 208 |
+
1. ``"C"`` for Cooperate
|
| 209 |
+
2. ``"D"`` for Defect
|
| 210 |
+
|
| 211 |
+
IPDAgent
|
| 212 |
+
--------------
|
| 213 |
+
|
| 214 |
+
The ``IPDAgent`` class implements the agent handler interface for the Iterated Prisoner's Dilemma, processing observations from the environment and generating actions through an LLM.
|
| 215 |
+
|
| 216 |
+
.. code-block:: python
|
| 217 |
+
|
| 218 |
+
class IPDAgent:
|
| 219 |
+
"""
|
| 220 |
+
Agent handler for Iterated Prisoner's Dilemma, implementing the AgentState interface
|
| 221 |
+
for the multi-agent negotiation standard.
|
| 222 |
+
"""
|
| 223 |
+
|
| 224 |
+
def __init__(
|
| 225 |
+
self,
|
| 226 |
+
agent_id: str,
|
| 227 |
+
policy_id: str = "llm_policy",
|
| 228 |
+
system_prompt: Optional[str] = None,
|
| 229 |
+
max_errors: int = 3,
|
| 230 |
+
opponent_id: Optional[str] = None,
|
| 231 |
+
):
|
| 232 |
+
"""
|
| 233 |
+
Initialize the IPD agent handler.
|
| 234 |
+
|
| 235 |
+
Args:
|
| 236 |
+
agent_id: Identifier for this agent ("alice" or "bob")
|
| 237 |
+
policy_id: Identifier for the policy this agent uses
|
| 238 |
+
system_prompt: Optional custom system prompt for the LLM
|
| 239 |
+
max_errors: Maximum number of parsing errors before defaulting to cooperate
|
| 240 |
+
opponent_id: Optional identifier of the opponent (inferred if not provided)
|
| 241 |
+
"""
|
| 242 |
+
# ...
|
| 243 |
+
|
| 244 |
+
def step(self, observation_from_env: Dict[str, Any], policy_output: str = None) -> Tuple[str, Dict[str, Any], str, bool, Dict[str, Any]]:
|
| 245 |
+
"""
|
| 246 |
+
Update the agent state based on the observation and process the policy output.
|
| 247 |
+
|
| 248 |
+
Args:
|
| 249 |
+
observation_from_env: The observation from the environment
|
| 250 |
+
policy_output: The output from the policy (LLM response)
|
| 251 |
+
|
| 252 |
+
Returns:
|
| 253 |
+
policy_id: The policy identifier
|
| 254 |
+
policy_input: The input to the policy
|
| 255 |
+
action: The action to be sent to the environment
|
| 256 |
+
done: Whether the action is ready to be sent to the environment
|
| 257 |
+
info: Additional information about the agent
|
| 258 |
+
"""
|
| 259 |
+
# ...
|
| 260 |
+
|
| 261 |
+
Key Implementation Details
|
| 262 |
+
~~~~~~~~~~~~~~~~~~~~~~~~~
|
| 263 |
+
|
| 264 |
+
The ``IPDAgent`` class implements several key features:
|
| 265 |
+
|
| 266 |
+
1. **LLM Interaction**: The agent generates prompts for an LLM and processes the LLM's responses.
|
| 267 |
+
|
| 268 |
+
2. **Action Extraction**: The agent parses the LLM's output to extract valid actions (C or D).
|
| 269 |
+
|
| 270 |
+
3. **Error Handling**: The agent provides helpful error messages when parsing fails and defaults to cooperation after multiple failures.
|
| 271 |
+
|
| 272 |
+
4. **History Tracking**: The agent maintains and provides the complete game history in its prompts.
|
| 273 |
+
|
| 274 |
+
5. **Strategy Explanation**: The agent can extract and log the reasoning behind an LLM's decisions.
|
| 275 |
+
|
| 276 |
+
Prompt Structure
|
| 277 |
+
~~~~~~~~~~~~~~~
|
| 278 |
+
|
| 279 |
+
The agent generates prompts that include:
|
| 280 |
+
|
| 281 |
+
1. **System Prompt**: Instructions and context for the LLM, explaining its role and the rules of the Prisoner's Dilemma.
|
| 282 |
+
|
| 283 |
+
2. **Game State Description**: A text description of the current game state, including:
|
| 284 |
+
- Current round number
|
| 285 |
+
- History of previous rounds (if any)
|
| 286 |
+
- Cumulative score
|
| 287 |
+
|
| 288 |
+
3. **Action Request**: Instructions on how to format the response, requiring an explicit action tag.
|
| 289 |
+
|
| 290 |
+
Example system prompt:
|
| 291 |
+
|
| 292 |
+
.. code-block:: text
|
| 293 |
+
|
| 294 |
+
You are playing as Alice in an Iterated Prisoner's Dilemma game against Bob.
|
| 295 |
+
In each round, you must choose to either Cooperate (C) or Defect (D).
|
| 296 |
+
|
| 297 |
+
The payoffs are:
|
| 298 |
+
- If both players Cooperate: You each get 3 points
|
| 299 |
+
- If both players Defect: You each get 1 point
|
| 300 |
+
- If you Cooperate and Bob Defects: You get 0 points, Bob gets 5 points
|
| 301 |
+
- If you Defect and Bob Cooperates: You get 5 points, Bob gets 0 points
|
| 302 |
+
|
| 303 |
+
Your goal is to maximize your total points across all rounds.
|
| 304 |
+
The game will last for exactly 10 rounds, and both players know this.
|
| 305 |
+
|
| 306 |
+
Example game state prompt:
|
| 307 |
+
|
| 308 |
+
.. code-block:: text
|
| 309 |
+
|
| 310 |
+
Current round: 3/10
|
| 311 |
+
|
| 312 |
+
History:
|
| 313 |
+
Round 1: You chose C, Bob chose C. You earned 3 points.
|
| 314 |
+
Round 2: You chose C, Bob chose D. You earned 0 points.
|
| 315 |
+
|
| 316 |
+
Your total score so far: 3 points
|
| 317 |
+
|
| 318 |
+
What is your choice for round 3?
|
| 319 |
+
Please respond with <action>C</action> to cooperate or <action>D</action> to defect,
|
| 320 |
+
and explain your reasoning.
|
| 321 |
+
|
| 322 |
+
Running IPD Games
|
| 323 |
+
----------------------
|
| 324 |
+
|
| 325 |
+
To run Iterated Prisoner's Dilemma games with LLM agents, you can use the following code structure:
|
| 326 |
+
|
| 327 |
+
.. code-block:: python
|
| 328 |
+
|
| 329 |
+
from mllm.environments.ipd.ipd_game import IPDEnv
|
| 330 |
+
from mllm.environments.ipd.ipd_agent import IPDAgent
|
| 331 |
+
from mllm.run_matches import run_batched_matches
|
| 332 |
+
|
| 333 |
+
# Create environment
|
| 334 |
+
env = IPDEnv(
|
| 335 |
+
rounds_per_game=10,
|
| 336 |
+
reward=3.0,
|
| 337 |
+
punishment=1.0,
|
| 338 |
+
temptation=5.0,
|
| 339 |
+
sucker=0.0
|
| 340 |
+
)
|
| 341 |
+
|
| 342 |
+
# Create agent handlers
|
| 343 |
+
agent_handlers = {
|
| 344 |
+
"alice": IPDAgent(agent_id="alice"),
|
| 345 |
+
"bob": IPDAgent(agent_id="bob")
|
| 346 |
+
}
|
| 347 |
+
|
| 348 |
+
# Define policy mapping
|
| 349 |
+
policy_mapping = {
|
| 350 |
+
"llm_policy": my_llm_policy_function
|
| 351 |
+
}
|
| 352 |
+
|
| 353 |
+
# Run the game
|
| 354 |
+
game_results = run_batched_matches(
|
| 355 |
+
envs=[env],
|
| 356 |
+
agent_handlers_per_env=[agent_handlers],
|
| 357 |
+
policy_mapping=policy_mapping,
|
| 358 |
+
max_parallel_matches=1
|
| 359 |
+
)
|
| 360 |
+
|
| 361 |
+
# Process results
|
| 362 |
+
for result in game_results:
|
| 363 |
+
print(f"Game finished. Scores: {result['total_rewards']}")
|
| 364 |
+
|
| 365 |
+
Statistics and Analysis
|
| 366 |
+
----------------------
|
| 367 |
+
|
| 368 |
+
The IPD environment includes utility functions for analyzing game outcomes:
|
| 369 |
+
|
| 370 |
+
1. **Cooperation Rates**: Percentage of rounds where each agent cooperated.
|
| 371 |
+
2. **Mutual Cooperation/Defection**: Percentage of rounds where both agents made the same choice.
|
| 372 |
+
3. **Score Distribution**: Analysis of how points were accumulated over the game.
|
| 373 |
+
|
| 374 |
+
These statistics can be calculated using the ``gather_ipd_statistics`` function:
|
| 375 |
+
|
| 376 |
+
.. code-block:: python
|
| 377 |
+
|
| 378 |
+
from mllm.environments.ipd.ipd_statistics_funcs import gather_ipd_statistics
|
| 379 |
+
|
| 380 |
+
stats = gather_ipd_statistics(match_info, env_info)
|
| 381 |
+
print(f"Cooperation rates: {stats['cooperation_rate']}")
|
| 382 |
+
print(f"Mutual cooperation rate: {stats['mutual_cooperation_rate']}")
|
| 383 |
+
print(f"Mutual defection rate: {stats['mutual_defection_rate']}")
|
| 384 |
+
|
| 385 |
+
Limitations and Considerations
|
| 386 |
+
-----------------------------
|
| 387 |
+
|
| 388 |
+
1. **Determinism**: The environment is deterministic, with randomness only in initialization if a seed is provided.
|
| 389 |
+
|
| 390 |
+
2. **Limited Player Count**: The IPD environment only supports exactly two players.
|
| 391 |
+
|
| 392 |
+
3. **Perfect Information**: Both players have perfect information about the game history.
|
| 393 |
+
|
| 394 |
+
4. **Simultaneous Actions**: Both players act simultaneously, which requires adaptations for some LLM interfaces.
|
| 395 |
+
|
| 396 |
+
5. **Fixed Game Length**: The total number of rounds is fixed and known to both players from the start.
|
| 397 |
+
|
| 398 |
+
Advanced Usage
|
| 399 |
+
------------
|
| 400 |
+
|
| 401 |
+
For advanced usage, you can customize:
|
| 402 |
+
|
| 403 |
+
1. **Payoff Matrix**: Modify reward values to create different incentive structures.
|
| 404 |
+
|
| 405 |
+
2. **System Prompts**: Customize the LLM's understanding of the game and potential strategies.
|
| 406 |
+
|
| 407 |
+
3. **Error Handling**: Adjust how the agent responds to invalid LLM outputs.
|
| 408 |
+
|
| 409 |
+
4. **Analysis**: Create custom statistics gathering for specific research questions.
|
| 410 |
+
|
| 411 |
+
5. **Integration**: Connect the IPD environment to other negotiation frameworks or tournament systems.
|
src_code_for_reproducibility/docs/source/launch.rst
ADDED
|
File without changes
|
src_code_for_reproducibility/docs/source/media/runbatch.png
ADDED
|
src_code_for_reproducibility/docs/source/modules.rst
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
src
|
| 2 |
+
===
|
| 3 |
+
|
| 4 |
+
.. toctree::
|
| 5 |
+
:maxdepth: 4
|
| 6 |
+
|
| 7 |
+
src
|
src_code_for_reproducibility/docs/source/src.generation.run_games.rst
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
src.generation.run\_games module
|
| 2 |
+
================================
|
| 3 |
+
|
| 4 |
+
.. automodule:: src.generation.run_games
|
| 5 |
+
:members:
|
| 6 |
+
:undoc-members:
|
| 7 |
+
:show-inheritance:
|
src_code_for_reproducibility/docs/source/src.models.dummy_hf_agent.rst
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
src.models.dummy\_hf\_agent module
|
| 2 |
+
==================================
|
| 3 |
+
|
| 4 |
+
.. automodule:: src.models.dummy_llm_agent
|
| 5 |
+
:members:
|
| 6 |
+
:undoc-members:
|
| 7 |
+
:show-inheritance:
|
src_code_for_reproducibility/docs/source/src.models.dummy_local_llm.rst
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
src.models.dummy\_local\_llm module
|
| 2 |
+
===================================
|
| 3 |
+
|
| 4 |
+
.. automodule:: src.models.dummy_local_llm
|
| 5 |
+
:members:
|
| 6 |
+
:undoc-members:
|
| 7 |
+
:show-inheritance:
|
src_code_for_reproducibility/docs/source/src.models.hf_agent.rst
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
src.models.hf\_agent module
|
| 2 |
+
===========================
|
| 3 |
+
|
| 4 |
+
.. automodule:: src.models.hf_agent
|
| 5 |
+
:members:
|
| 6 |
+
:undoc-members:
|
| 7 |
+
:show-inheritance:
|
src_code_for_reproducibility/docs/source/src.models.local_llm.rst
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
src.models.local\_llm module
|
| 2 |
+
============================
|
| 3 |
+
|
| 4 |
+
.. automodule:: src.models.local_llm
|
| 5 |
+
:members:
|
| 6 |
+
:undoc-members:
|
| 7 |
+
:show-inheritance:
|
src_code_for_reproducibility/docs/source/src.models.oai_agent.rst
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
src.models.oai\_agent module
|
| 2 |
+
============================
|
| 3 |
+
|
| 4 |
+
.. automodule:: src.models.oai_agent
|
| 5 |
+
:members:
|
| 6 |
+
:undoc-members:
|
| 7 |
+
:show-inheritance:
|
src_code_for_reproducibility/docs/source/src.models.server_llm.rst
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
src.models.server\_llm module
|
| 2 |
+
=============================
|
| 3 |
+
|
| 4 |
+
.. automodule:: src.models.server_llm
|
| 5 |
+
:members:
|
| 6 |
+
:undoc-members:
|
| 7 |
+
:show-inheritance:
|
src_code_for_reproducibility/docs/source/src.run.rst
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
src.run module
|
| 2 |
+
==============
|
| 3 |
+
|
| 4 |
+
.. automodule:: src.run
|
| 5 |
+
:members:
|
| 6 |
+
:undoc-members:
|
| 7 |
+
:show-inheritance:
|
src_code_for_reproducibility/docs/source/src.training.ppo_train.rst
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
src.training.ppo\_train module
|
| 2 |
+
==============================
|
| 3 |
+
|
| 4 |
+
.. automodule:: src.training.ppo_train
|
| 5 |
+
:members:
|
| 6 |
+
:undoc-members:
|
| 7 |
+
:show-inheritance:
|
src_code_for_reproducibility/docs/source/src.training.train_main.rst
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
src.training.train\_main module
|
| 2 |
+
===============================
|
| 3 |
+
|
| 4 |
+
.. automodule:: src.training.train_main
|
| 5 |
+
:members:
|
| 6 |
+
:undoc-members:
|
| 7 |
+
:show-inheritance:
|
src_code_for_reproducibility/docs/source/src.utils.common_imports.rst
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
src.utils.common\_imports module
|
| 2 |
+
================================
|
| 3 |
+
|
| 4 |
+
.. automodule:: src.utils.common_imports
|
| 5 |
+
:members:
|
| 6 |
+
:undoc-members:
|
| 7 |
+
:show-inheritance:
|
src_code_for_reproducibility/docs/source/src.utils.extra_stats.rst
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
src.utils.extra\_stats module
|
| 2 |
+
=============================
|
| 3 |
+
|
| 4 |
+
.. automodule:: src.utils.extra_stats
|
| 5 |
+
:members:
|
| 6 |
+
:undoc-members:
|
| 7 |
+
:show-inheritance:
|
src_code_for_reproducibility/docs/source/src.utils.inherit_args.rst
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
src.utils.inherit\_args module
|
| 2 |
+
==============================
|
| 3 |
+
|
| 4 |
+
.. automodule:: src.utils.inherit_args
|
| 5 |
+
:members:
|
| 6 |
+
:undoc-members:
|
| 7 |
+
:show-inheritance:
|
src_code_for_reproducibility/docs/source/src.utils.log_gpu_usage.rst
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
src.utils.log\_gpu\_usage module
|
| 2 |
+
================================
|
| 3 |
+
|
| 4 |
+
.. automodule:: src.utils.log_gpu_usage
|
| 5 |
+
:members:
|
| 6 |
+
:undoc-members:
|
| 7 |
+
:show-inheritance:
|
src_code_for_reproducibility/docs/source/src.utils.log_statistics.rst
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
src.utils.log\_statistics module
|
| 2 |
+
================================
|
| 3 |
+
|
| 4 |
+
.. automodule:: src.utils.log_statistics
|
| 5 |
+
:members:
|
| 6 |
+
:undoc-members:
|
| 7 |
+
:show-inheritance:
|
src_code_for_reproducibility/docs/source/src.utils.model_to_cpu.rst
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
src.utils.model\_to\_cpu module
|
| 2 |
+
===============================
|
| 3 |
+
|
| 4 |
+
.. automodule:: src.utils.model_to_cpu
|
| 5 |
+
:members:
|
| 6 |
+
:undoc-members:
|
| 7 |
+
:show-inheritance:
|
src_code_for_reproducibility/docs/source/src.utils.quick_stats.rst
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
src.utils.quick\_stats module
|
| 2 |
+
=============================
|
| 3 |
+
|
| 4 |
+
.. automodule:: src.utils.quick_stats
|
| 5 |
+
:members:
|
| 6 |
+
:undoc-members:
|
| 7 |
+
:show-inheritance:
|
src_code_for_reproducibility/docs/source/usage.rst
ADDED
|
File without changes
|
src_code_for_reproducibility/models/human_policy.py
ADDED
|
@@ -0,0 +1,255 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import asyncio
|
| 2 |
+
import os
|
| 3 |
+
import re
|
| 4 |
+
import shutil
|
| 5 |
+
import sys
|
| 6 |
+
from typing import Callable, Dict, List, Optional
|
| 7 |
+
|
| 8 |
+
from mllm.markov_games.rollout_tree import ChatTurn
|
| 9 |
+
|
| 10 |
+
try:
|
| 11 |
+
import rstr # For generating example strings from regex
|
| 12 |
+
except Exception: # pragma: no cover
|
| 13 |
+
rstr = None
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
def _clear_terminal() -> None:
|
| 17 |
+
"""
|
| 18 |
+
Clear the terminal screen in a cross-platform manner.
|
| 19 |
+
"""
|
| 20 |
+
if sys.stdout.isatty():
|
| 21 |
+
os.system("cls" if os.name == "nt" else "clear")
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
def _terminal_width(default: int = 100) -> int:
|
| 25 |
+
try:
|
| 26 |
+
return shutil.get_terminal_size().columns
|
| 27 |
+
except Exception:
|
| 28 |
+
return default
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
def _horizontal_rule(char: str = "─") -> str:
|
| 32 |
+
width = max(20, _terminal_width() - 2)
|
| 33 |
+
return char * width
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
class _Style:
|
| 37 |
+
# ANSI colors (bright, readable)
|
| 38 |
+
RESET = "\033[0m"
|
| 39 |
+
BOLD = "\033[1m"
|
| 40 |
+
DIM = "\033[2m"
|
| 41 |
+
# Foreground colors
|
| 42 |
+
FG_BLUE = "\033[94m" # user/system headers
|
| 43 |
+
FG_GREEN = "\033[92m" # human response header
|
| 44 |
+
FG_YELLOW = "\033[93m" # notices
|
| 45 |
+
FG_RED = "\033[91m" # errors
|
| 46 |
+
FG_MAGENTA = "\033[95m" # regex
|
| 47 |
+
FG_CYAN = "\033[96m" # tips
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
def _render_chat(state) -> str:
|
| 51 |
+
"""
|
| 52 |
+
Render prior messages in a compact, readable terminal format.
|
| 53 |
+
|
| 54 |
+
Expected message dict keys: {"role": str, "content": str, ...}
|
| 55 |
+
"""
|
| 56 |
+
lines: List[str] = []
|
| 57 |
+
lines.append(_horizontal_rule())
|
| 58 |
+
lines.append(f"{_Style.FG_BLUE}{_Style.BOLD} Conversation so far {_Style.RESET}")
|
| 59 |
+
lines.append(_horizontal_rule())
|
| 60 |
+
for chat in state:
|
| 61 |
+
role = chat.role
|
| 62 |
+
content = str(chat.content).strip()
|
| 63 |
+
# Map roles to display names and colors/emojis
|
| 64 |
+
if role == "assistant":
|
| 65 |
+
header = f"{_Style.FG_GREEN}{_Style.BOLD}HUMAN--🧑💻{_Style.RESET}"
|
| 66 |
+
elif role == "user":
|
| 67 |
+
header = f"{_Style.FG_BLUE}{_Style.BOLD}USER--⚙️{_Style.RESET}"
|
| 68 |
+
else:
|
| 69 |
+
header = f"[{_Style.DIM}{role.upper()}{_Style.RESET}]"
|
| 70 |
+
lines.append(header)
|
| 71 |
+
# Indent content for readability
|
| 72 |
+
for line in content.splitlines() or [""]:
|
| 73 |
+
lines.append(f" {line}")
|
| 74 |
+
lines.append("")
|
| 75 |
+
lines.append(_horizontal_rule())
|
| 76 |
+
return "\n".join(lines)
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
async def _async_input(prompt_text: str) -> str:
|
| 80 |
+
"""Non-blocking input using a background thread."""
|
| 81 |
+
return await asyncio.to_thread(input, prompt_text)
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
def _short_regex_example(regex: str, max_len: int = 30) -> Optional[str]:
|
| 85 |
+
"""
|
| 86 |
+
Try to produce a short example string that matches the regex.
|
| 87 |
+
We attempt multiple times and pick the first <= max_len.
|
| 88 |
+
"""
|
| 89 |
+
if rstr is None:
|
| 90 |
+
return None
|
| 91 |
+
try:
|
| 92 |
+
for _ in range(20):
|
| 93 |
+
candidate = rstr.xeger(regex)
|
| 94 |
+
if len(candidate) <= max_len:
|
| 95 |
+
return candidate
|
| 96 |
+
# Fallback to truncation (may break match, so don't return)
|
| 97 |
+
return None
|
| 98 |
+
except Exception:
|
| 99 |
+
return None
|
| 100 |
+
|
| 101 |
+
|
| 102 |
+
def _detect_input_type(regex: str | None) -> tuple[str, str, str]:
|
| 103 |
+
"""
|
| 104 |
+
Detect what type of input is expected based on the regex pattern.
|
| 105 |
+
Returns (input_type, start_tag, end_tag)
|
| 106 |
+
"""
|
| 107 |
+
if regex is None:
|
| 108 |
+
return "text", "", ""
|
| 109 |
+
|
| 110 |
+
if "message_start" in regex and "message_end" in regex:
|
| 111 |
+
return "message", "<<message_start>>", "<<message_end>>"
|
| 112 |
+
elif "proposal_start" in regex and "proposal_end" in regex:
|
| 113 |
+
return "proposal", "<<proposal_start>>", "<<proposal_end>>"
|
| 114 |
+
else:
|
| 115 |
+
return "text", "", ""
|
| 116 |
+
|
| 117 |
+
|
| 118 |
+
async def human_policy(state, agent_id, regex: str | None = None) -> str:
|
| 119 |
+
"""
|
| 120 |
+
Async human-in-the-loop policy.
|
| 121 |
+
|
| 122 |
+
- Displays prior conversation context in the terminal.
|
| 123 |
+
- Prompts the user for a response.
|
| 124 |
+
- If a regex is provided, validates and re-prompts until it matches.
|
| 125 |
+
- Automatically adds formatting tags based on expected input type.
|
| 126 |
+
|
| 127 |
+
Args:
|
| 128 |
+
prompt: Chat history as a list of {role, content} dicts.
|
| 129 |
+
regex: Optional fullmatch validation pattern.
|
| 130 |
+
|
| 131 |
+
Returns:
|
| 132 |
+
The user's validated response string.
|
| 133 |
+
"""
|
| 134 |
+
# Detect input type and formatting
|
| 135 |
+
input_type, start_tag, end_tag = _detect_input_type(regex)
|
| 136 |
+
|
| 137 |
+
while True:
|
| 138 |
+
_clear_terminal()
|
| 139 |
+
print(_render_chat(state))
|
| 140 |
+
|
| 141 |
+
if regex:
|
| 142 |
+
example = _short_regex_example(regex, max_len=30)
|
| 143 |
+
print(
|
| 144 |
+
f"{_Style.FG_MAGENTA}{_Style.BOLD}Expected format (regex fullmatch):{_Style.RESET}"
|
| 145 |
+
)
|
| 146 |
+
print(f" {_Style.FG_MAGENTA}{regex}{_Style.RESET}")
|
| 147 |
+
if example:
|
| 148 |
+
print(
|
| 149 |
+
f"{_Style.FG_CYAN}Example (random, <=30 chars):{_Style.RESET} {example}"
|
| 150 |
+
)
|
| 151 |
+
print(_horizontal_rule("."))
|
| 152 |
+
|
| 153 |
+
# Custom prompt based on input type
|
| 154 |
+
if input_type == "message":
|
| 155 |
+
print(
|
| 156 |
+
f"{_Style.FG_YELLOW}Type your message content (formatting will be added automatically):{_Style.RESET}"
|
| 157 |
+
)
|
| 158 |
+
elif input_type == "proposal":
|
| 159 |
+
print(
|
| 160 |
+
f"{_Style.FG_YELLOW}Type your proposal (number only, formatting will be added automatically):{_Style.RESET}"
|
| 161 |
+
)
|
| 162 |
+
else:
|
| 163 |
+
print(
|
| 164 |
+
f"{_Style.FG_YELLOW}Type your response and press Enter.{_Style.RESET}"
|
| 165 |
+
)
|
| 166 |
+
|
| 167 |
+
print(
|
| 168 |
+
f"{_Style.DIM}Commands: /help to view commands, /refresh to re-render, /quit to abort{_Style.RESET}"
|
| 169 |
+
)
|
| 170 |
+
else:
|
| 171 |
+
print(
|
| 172 |
+
f"{_Style.FG_YELLOW}Type your response and press Enter.{_Style.RESET} {_Style.DIM}(/help for commands){_Style.RESET}"
|
| 173 |
+
)
|
| 174 |
+
|
| 175 |
+
user_in = (await _async_input("> ")).rstrip("\n")
|
| 176 |
+
|
| 177 |
+
# Commands
|
| 178 |
+
if user_in.strip().lower() in {"/help", "/h"}:
|
| 179 |
+
print(f"\n{_Style.FG_CYAN}{_Style.BOLD}Available commands:{_Style.RESET}")
|
| 180 |
+
print(
|
| 181 |
+
f" {_Style.FG_CYAN}/help{_Style.RESET} or {_Style.FG_CYAN}/h{_Style.RESET} Show this help"
|
| 182 |
+
)
|
| 183 |
+
print(
|
| 184 |
+
f" {_Style.FG_CYAN}/refresh{_Style.RESET} or {_Style.FG_CYAN}/r{_Style.RESET} Re-render the conversation and prompt"
|
| 185 |
+
)
|
| 186 |
+
print(
|
| 187 |
+
f" {_Style.FG_CYAN}/quit{_Style.RESET} or {_Style.FG_CYAN}/q{_Style.RESET} Abort the run (raises KeyboardInterrupt)"
|
| 188 |
+
)
|
| 189 |
+
await asyncio.sleep(1.0)
|
| 190 |
+
continue
|
| 191 |
+
if user_in.strip().lower() in {"/refresh", "/r"}:
|
| 192 |
+
continue
|
| 193 |
+
if user_in.strip().lower() in {"/quit", "/q"}:
|
| 194 |
+
raise KeyboardInterrupt("Human aborted run from human_policy")
|
| 195 |
+
|
| 196 |
+
# Add formatting tags if needed
|
| 197 |
+
if start_tag and end_tag:
|
| 198 |
+
formatted_input = f"{start_tag}{user_in}{end_tag}"
|
| 199 |
+
else:
|
| 200 |
+
formatted_input = user_in
|
| 201 |
+
|
| 202 |
+
if regex is None:
|
| 203 |
+
return ChatTurn(
|
| 204 |
+
role="assistant", agent_id=agent_id, content=formatted_input
|
| 205 |
+
)
|
| 206 |
+
|
| 207 |
+
# Validate against regex (fullmatch)
|
| 208 |
+
try:
|
| 209 |
+
pattern = re.compile(regex)
|
| 210 |
+
except re.error as e:
|
| 211 |
+
# If regex is invalid, fall back to accepting any input
|
| 212 |
+
print(
|
| 213 |
+
f"{_Style.FG_RED}Warning:{_Style.RESET} Provided regex is invalid: {e}. Accepting input without validation."
|
| 214 |
+
)
|
| 215 |
+
await asyncio.sleep(0.5)
|
| 216 |
+
return ChatTurn(
|
| 217 |
+
role="assistant", agent_id=agent_id, content=formatted_input
|
| 218 |
+
)
|
| 219 |
+
|
| 220 |
+
if pattern.fullmatch(formatted_input):
|
| 221 |
+
return ChatTurn(
|
| 222 |
+
role="assistant", agent_id=agent_id, content=formatted_input
|
| 223 |
+
)
|
| 224 |
+
|
| 225 |
+
# Show validation error and re-prompt
|
| 226 |
+
print("")
|
| 227 |
+
print(
|
| 228 |
+
f"{_Style.FG_RED}{_Style.BOLD}Input did not match the required format.{_Style.RESET} Please try again."
|
| 229 |
+
)
|
| 230 |
+
|
| 231 |
+
if input_type == "message":
|
| 232 |
+
print(
|
| 233 |
+
f"You entered: {_Style.FG_CYAN}{start_tag}{user_in}{end_tag}{_Style.RESET}"
|
| 234 |
+
)
|
| 235 |
+
print(f"Just type the message content without tags.")
|
| 236 |
+
elif input_type == "proposal":
|
| 237 |
+
print(
|
| 238 |
+
f"You entered: {_Style.FG_CYAN}{start_tag}{user_in}{end_tag}{_Style.RESET}"
|
| 239 |
+
)
|
| 240 |
+
print(f"Just type the number without tags.")
|
| 241 |
+
else:
|
| 242 |
+
print(f"Expected (regex):")
|
| 243 |
+
print(f" {_Style.FG_MAGENTA}{regex}{_Style.RESET}")
|
| 244 |
+
|
| 245 |
+
print(_horizontal_rule("."))
|
| 246 |
+
print(f"{_Style.FG_YELLOW}Press Enter to retry...{_Style.RESET}")
|
| 247 |
+
await _async_input("")
|
| 248 |
+
|
| 249 |
+
|
| 250 |
+
def get_human_policies() -> Dict[str, Callable[[List[Dict]], str]]:
|
| 251 |
+
"""
|
| 252 |
+
Expose the human policy in the same map shape used elsewhere.
|
| 253 |
+
"""
|
| 254 |
+
# Type hint says Callable[[List[Dict]], str] but we intentionally return the async callable.
|
| 255 |
+
return {"human_policy": human_policy} # type: ignore[return-value]
|
src_code_for_reproducibility/training/__init__.py
ADDED
|
File without changes
|
src_code_for_reproducibility/training/tally_tokenwise.py
ADDED
|
@@ -0,0 +1,276 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
import os
|
| 3 |
+
from typing import Any, Dict, List, Tuple, Union
|
| 4 |
+
|
| 5 |
+
import numpy as np
|
| 6 |
+
import pandas as pd
|
| 7 |
+
import torch
|
| 8 |
+
from transformers import AutoTokenizer
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
class ContextualizedTokenwiseTally:
|
| 12 |
+
"""
|
| 13 |
+
Collect, store, and save token-level metrics per rollout.
|
| 14 |
+
|
| 15 |
+
- One DataFrame per rollout_id in `paths`
|
| 16 |
+
- Index = timestep (int)
|
| 17 |
+
- Columns are added incrementally via `add_contexts()` and `add_data()`
|
| 18 |
+
- Cells may contain scalars, strings, or lists (dtype=object)
|
| 19 |
+
"""
|
| 20 |
+
|
| 21 |
+
def __init__(
|
| 22 |
+
self,
|
| 23 |
+
tokenizer: AutoTokenizer,
|
| 24 |
+
paths: List[str],
|
| 25 |
+
max_context_length: int = 30,
|
| 26 |
+
):
|
| 27 |
+
"""
|
| 28 |
+
Args:
|
| 29 |
+
tokenizer: HuggingFace tokenizer used to convert tids -> tokens
|
| 30 |
+
paths: rollout identifiers (parallel to batch dimension)
|
| 31 |
+
max_context_length: truncate context token lists to this length
|
| 32 |
+
"""
|
| 33 |
+
self.tokenizer = tokenizer
|
| 34 |
+
self.paths = paths
|
| 35 |
+
self.max_context_length = max_context_length
|
| 36 |
+
self.tally: Dict[str, pd.DataFrame] = {path: pd.DataFrame() for path in paths}
|
| 37 |
+
|
| 38 |
+
# set later by setters
|
| 39 |
+
self.contexts: torch.Tensor | None = None
|
| 40 |
+
self.action_mask: torch.Tensor | None = None
|
| 41 |
+
self.range: Tuple[int, int] | None = None
|
| 42 |
+
|
| 43 |
+
# --------- Utilities ---------
|
| 44 |
+
|
| 45 |
+
def tids_to_str(self, tids: List[int]) -> List[str]:
|
| 46 |
+
"""Convert a list of token IDs to a list of token strings."""
|
| 47 |
+
return self.tokenizer.convert_ids_to_tokens(tids)
|
| 48 |
+
|
| 49 |
+
def _ensure_ready(self):
|
| 50 |
+
assert self.action_mask is not None, "call set_action_mask(mask) first"
|
| 51 |
+
assert self.range is not None, "call set_range((start, end)) first"
|
| 52 |
+
|
| 53 |
+
@staticmethod
|
| 54 |
+
def _sanitize_filename(name: Any) -> str:
|
| 55 |
+
"""Make a safe filename from any rollout_id."""
|
| 56 |
+
s = str(name)
|
| 57 |
+
bad = {os.sep, " ", ":", "|", "<", ">", '"', "'"}
|
| 58 |
+
if os.altsep is not None:
|
| 59 |
+
bad.add(os.altsep)
|
| 60 |
+
for ch in bad:
|
| 61 |
+
s = s.replace(ch, "_")
|
| 62 |
+
return s
|
| 63 |
+
|
| 64 |
+
@staticmethod
|
| 65 |
+
def _pad_left(seq: List[Any], length: int, pad_val: Any = "") -> List[Any]:
|
| 66 |
+
"""Left-pad a sequence to `length` with `pad_val`."""
|
| 67 |
+
if len(seq) >= length:
|
| 68 |
+
return seq[-length:]
|
| 69 |
+
return [pad_val] * (length - len(seq)) + list(seq)
|
| 70 |
+
|
| 71 |
+
# --------- Setters ---------
|
| 72 |
+
|
| 73 |
+
def set_action_mask(self, action_mask: torch.Tensor):
|
| 74 |
+
"""
|
| 75 |
+
action_mask: (B, S) bool or 0/1 indicating valid steps
|
| 76 |
+
"""
|
| 77 |
+
self.action_mask = action_mask
|
| 78 |
+
|
| 79 |
+
def set_range(self, range: Tuple[int, int]):
|
| 80 |
+
"""
|
| 81 |
+
range: slice (start, end) into self.paths for current batch
|
| 82 |
+
"""
|
| 83 |
+
self.range = range
|
| 84 |
+
|
| 85 |
+
# --------- Column builders ---------
|
| 86 |
+
|
| 87 |
+
def add_contexts(self, contexts: torch.Tensor):
|
| 88 |
+
"""
|
| 89 |
+
Add a single 'context' column (list[str]) for valid steps.
|
| 90 |
+
|
| 91 |
+
Expects `contexts` with shape (B, S): token id at each timestep.
|
| 92 |
+
For each valid timestep t, we use the last N tokens up to and including t:
|
| 93 |
+
window = contexts[i, max(0, t - N + 1) : t + 1]
|
| 94 |
+
The list is left-padded with "" to always be length N.
|
| 95 |
+
"""
|
| 96 |
+
self._ensure_ready()
|
| 97 |
+
|
| 98 |
+
current_paths = self.paths[self.range[0] : self.range[1]]
|
| 99 |
+
B, S = contexts.shape
|
| 100 |
+
N = self.max_context_length
|
| 101 |
+
|
| 102 |
+
# to CPU ints once
|
| 103 |
+
contexts_cpu = contexts.detach().to("cpu")
|
| 104 |
+
|
| 105 |
+
for i in range(B):
|
| 106 |
+
rollout_id = current_paths[i]
|
| 107 |
+
df = self.tally.get(rollout_id, pd.DataFrame())
|
| 108 |
+
|
| 109 |
+
valid_idx = torch.nonzero(
|
| 110 |
+
self.action_mask[i].bool(), as_tuple=False
|
| 111 |
+
).squeeze(-1)
|
| 112 |
+
if valid_idx.numel() == 0:
|
| 113 |
+
self.tally[rollout_id] = df
|
| 114 |
+
continue
|
| 115 |
+
|
| 116 |
+
idx_list = valid_idx.tolist()
|
| 117 |
+
|
| 118 |
+
# ensure index contains valid steps
|
| 119 |
+
if df.empty:
|
| 120 |
+
df = pd.DataFrame(index=idx_list)
|
| 121 |
+
else:
|
| 122 |
+
new_index = sorted(set(df.index.tolist()) | set(idx_list))
|
| 123 |
+
if list(df.index) != new_index:
|
| 124 |
+
df = df.reindex(new_index)
|
| 125 |
+
|
| 126 |
+
# build context windows
|
| 127 |
+
ctx_token_lists = []
|
| 128 |
+
for t in idx_list:
|
| 129 |
+
start = max(0, t - N + 1)
|
| 130 |
+
window_ids = contexts_cpu[i, start : t + 1].tolist()
|
| 131 |
+
window_toks = self.tids_to_str([int(x) for x in window_ids])
|
| 132 |
+
if len(window_toks) < N:
|
| 133 |
+
window_toks = [""] * (N - len(window_toks)) + window_toks
|
| 134 |
+
else:
|
| 135 |
+
window_toks = window_toks[-N:]
|
| 136 |
+
ctx_token_lists.append(window_toks)
|
| 137 |
+
|
| 138 |
+
# single 'context' column
|
| 139 |
+
if "context" not in df.columns:
|
| 140 |
+
df["context"] = pd.Series(index=df.index, dtype=object)
|
| 141 |
+
df.loc[idx_list, "context"] = pd.Series(
|
| 142 |
+
ctx_token_lists, index=idx_list, dtype=object
|
| 143 |
+
)
|
| 144 |
+
|
| 145 |
+
self.tally[rollout_id] = df
|
| 146 |
+
|
| 147 |
+
def add_data(
|
| 148 |
+
self,
|
| 149 |
+
metric_id: str,
|
| 150 |
+
metrics: torch.Tensor,
|
| 151 |
+
to_tids: bool = False,
|
| 152 |
+
):
|
| 153 |
+
"""
|
| 154 |
+
Add a metric column for valid steps.
|
| 155 |
+
|
| 156 |
+
Args:
|
| 157 |
+
metric_id: column name
|
| 158 |
+
metrics: shape (B, S) for scalars/ids or (B, S, K) for top-k vectors
|
| 159 |
+
to_tids: if True, treat ints/lists of ints as tids and convert to tokens
|
| 160 |
+
"""
|
| 161 |
+
self._ensure_ready()
|
| 162 |
+
current_paths = self.paths[self.range[0] : self.range[1]]
|
| 163 |
+
|
| 164 |
+
if metrics.dim() == 2:
|
| 165 |
+
B, S = metrics.shape
|
| 166 |
+
elif metrics.dim() == 3:
|
| 167 |
+
B, S, _ = metrics.shape
|
| 168 |
+
else:
|
| 169 |
+
raise ValueError("metrics must be (B, S) or (B, S, K)")
|
| 170 |
+
|
| 171 |
+
for i in range(B):
|
| 172 |
+
rollout_id = current_paths[i]
|
| 173 |
+
df = self.tally.get(rollout_id, pd.DataFrame())
|
| 174 |
+
|
| 175 |
+
valid_idx = torch.nonzero(
|
| 176 |
+
self.action_mask[i].bool(), as_tuple=False
|
| 177 |
+
).squeeze(-1)
|
| 178 |
+
if valid_idx.numel() == 0:
|
| 179 |
+
self.tally[rollout_id] = df
|
| 180 |
+
continue
|
| 181 |
+
|
| 182 |
+
idx_list = valid_idx.detach().cpu().tolist()
|
| 183 |
+
|
| 184 |
+
# Ensure index contains valid steps
|
| 185 |
+
if df.empty:
|
| 186 |
+
df = pd.DataFrame(index=idx_list)
|
| 187 |
+
else:
|
| 188 |
+
new_index = sorted(set(df.index.tolist()) | set(idx_list))
|
| 189 |
+
if list(df.index) != new_index:
|
| 190 |
+
df = df.reindex(new_index)
|
| 191 |
+
|
| 192 |
+
# Slice metrics at valid steps
|
| 193 |
+
m_valid = metrics[i][valid_idx]
|
| 194 |
+
|
| 195 |
+
# -> pure python lists (1D list or list-of-lists)
|
| 196 |
+
values = m_valid.detach().cpu().tolist()
|
| 197 |
+
|
| 198 |
+
# optional tids -> tokens
|
| 199 |
+
if to_tids:
|
| 200 |
+
|
| 201 |
+
def _to_tokish(x):
|
| 202 |
+
if isinstance(x, list):
|
| 203 |
+
return self.tids_to_str([int(v) for v in x])
|
| 204 |
+
else:
|
| 205 |
+
return self.tids_to_str([int(x)])[0]
|
| 206 |
+
|
| 207 |
+
values = [_to_tokish(v) for v in values]
|
| 208 |
+
|
| 209 |
+
# Ensure column exists with object dtype, then assign via aligned Series
|
| 210 |
+
if metric_id not in df.columns:
|
| 211 |
+
df[metric_id] = pd.Series(index=df.index, dtype=object)
|
| 212 |
+
|
| 213 |
+
if isinstance(values, np.ndarray):
|
| 214 |
+
values = values.tolist()
|
| 215 |
+
|
| 216 |
+
if len(values) != len(idx_list):
|
| 217 |
+
raise ValueError(
|
| 218 |
+
f"Length mismatch for '{metric_id}': values={len(values)} vs idx_list={len(idx_list)}"
|
| 219 |
+
)
|
| 220 |
+
|
| 221 |
+
df.loc[idx_list, metric_id] = pd.Series(
|
| 222 |
+
values, index=idx_list, dtype=object
|
| 223 |
+
)
|
| 224 |
+
self.tally[rollout_id] = df
|
| 225 |
+
|
| 226 |
+
# --------- Saving ---------
|
| 227 |
+
|
| 228 |
+
def save(self, path: str):
|
| 229 |
+
"""
|
| 230 |
+
Write a manifest JSON and one CSV per rollout.
|
| 231 |
+
|
| 232 |
+
- Manifest includes metadata only (safe to JSON).
|
| 233 |
+
- Each rollout CSV is written with index label 'timestep'.
|
| 234 |
+
- Only a single 'context' column (list[str]).
|
| 235 |
+
"""
|
| 236 |
+
if not self.tally or all(df.empty for df in self.tally.values()):
|
| 237 |
+
return
|
| 238 |
+
|
| 239 |
+
os.makedirs(path, exist_ok=True)
|
| 240 |
+
from datetime import datetime
|
| 241 |
+
|
| 242 |
+
now = datetime.now()
|
| 243 |
+
|
| 244 |
+
manifest = {
|
| 245 |
+
"created_at": f"{now:%Y-%m-%d %H:%M:%S}",
|
| 246 |
+
"max_context_length": self.max_context_length,
|
| 247 |
+
"num_rollouts": len(self.tally),
|
| 248 |
+
"rollouts": [],
|
| 249 |
+
}
|
| 250 |
+
|
| 251 |
+
for rid, df in self.tally.items():
|
| 252 |
+
rid_str = str(rid)
|
| 253 |
+
safe_name = self._sanitize_filename(rid_str)
|
| 254 |
+
csv_path = os.path.join(path, f"{safe_name}_tokenwise.csv")
|
| 255 |
+
|
| 256 |
+
# Put 'context' first, then the rest
|
| 257 |
+
cols = ["context"] + [c for c in df.columns if c != "context"]
|
| 258 |
+
try:
|
| 259 |
+
df[cols].to_csv(csv_path, index=True, index_label="timestep")
|
| 260 |
+
except Exception as e:
|
| 261 |
+
continue
|
| 262 |
+
|
| 263 |
+
manifest["rollouts"].append(
|
| 264 |
+
{
|
| 265 |
+
"rollout_id": rid_str,
|
| 266 |
+
"csv": csv_path,
|
| 267 |
+
"num_rows": int(df.shape[0]),
|
| 268 |
+
"columns": cols,
|
| 269 |
+
}
|
| 270 |
+
)
|
| 271 |
+
|
| 272 |
+
manifest_path = os.path.join(
|
| 273 |
+
path, f"tokenwise_manifest_{now:%Y-%m-%d___%H-%M-%S}.json"
|
| 274 |
+
)
|
| 275 |
+
with open(manifest_path, "w") as fp:
|
| 276 |
+
json.dump(manifest, fp, indent=2)
|
src_code_for_reproducibility/training/tokenize_chats.py
ADDED
|
@@ -0,0 +1,128 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import logging
|
| 2 |
+
import sys
|
| 3 |
+
|
| 4 |
+
import regex
|
| 5 |
+
import torch
|
| 6 |
+
from transformers import AutoTokenizer
|
| 7 |
+
|
| 8 |
+
from mllm.training.training_data_utils import TrainingChatTurn, TrajectoryBatch
|
| 9 |
+
|
| 10 |
+
logger = logging.getLogger(__name__)
|
| 11 |
+
logger.addHandler(logging.StreamHandler(sys.stdout))
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
# def get_chat_dicts(chat: list[TrainingChatTurn]) -> list[dict]:
|
| 15 |
+
# chat_dicts = [chat_turn.dict() for chat_turn in chat]
|
| 16 |
+
# return chat_dicts
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
def process_training_chat(
|
| 20 |
+
tokenizer: AutoTokenizer,
|
| 21 |
+
chat_history: list[TrainingChatTurn],
|
| 22 |
+
entropy_mask_regex: str | None = None,
|
| 23 |
+
exploration_prompts_to_remove: list[str] = [],
|
| 24 |
+
use_engine_out_token_ids: bool = False,
|
| 25 |
+
) -> tuple[torch.IntTensor, torch.BoolTensor, torch.IntTensor, torch.BoolTensor]:
|
| 26 |
+
"""Tokenize a single training chat and build aligned per-token masks.
|
| 27 |
+
|
| 28 |
+
Given an ordered list of `TrainingChatTurn`, this function tokenizes each
|
| 29 |
+
turn independently using the tokenizer's chat template, then concatenates
|
| 30 |
+
all resulting token sequences. It also constructs three parallel 1D masks
|
| 31 |
+
that align with the concatenated tokens:
|
| 32 |
+
|
| 33 |
+
- input_ids: token ids for the entire chat, turn by turn
|
| 34 |
+
- action_mask: True for tokens that belong to assistant turns (i.e., model
|
| 35 |
+
actions), False for tokens from other roles
|
| 36 |
+
- timesteps: per-token time step copied from the originating turn's
|
| 37 |
+
`time_step`
|
| 38 |
+
- state_ends_mask: True for the last token of any turn where
|
| 39 |
+
`is_state_end` is True, otherwise False
|
| 40 |
+
|
| 41 |
+
Important details:
|
| 42 |
+
- Each turn is passed as a single-message list to
|
| 43 |
+
`tokenizer.apply_chat_template` and flattened; the per-turn outputs are
|
| 44 |
+
then concatenated in the original order.
|
| 45 |
+
- Turn boundaries are not explicitly encoded beyond what the chat template
|
| 46 |
+
inserts; masks provide alignment for learning signals and state endings.
|
| 47 |
+
- No truncation or padding is performed here; downstream code should handle
|
| 48 |
+
batching/padding as needed.
|
| 49 |
+
- Note on dtypes: `input_ids` will be a LongTensor (int64). `action_mask`
|
| 50 |
+
and `state_ends_mask` are BoolTensors. `timesteps` is currently created
|
| 51 |
+
as a float tensor; adjust the implementation if integer dtype is
|
| 52 |
+
required downstream.
|
| 53 |
+
|
| 54 |
+
Args:
|
| 55 |
+
tokenizer: A Hugging Face tokenizer supporting `apply_chat_template`.
|
| 56 |
+
chat_history: Ordered list of `TrainingChatTurn` forming one dialogue.
|
| 57 |
+
|
| 58 |
+
Returns:
|
| 59 |
+
A tuple of four 1D tensors, all of equal length N (the total number of
|
| 60 |
+
tokens across all turns), in the following order:
|
| 61 |
+
- input_ids (LongTensor)
|
| 62 |
+
- action_mask (BoolTensor)
|
| 63 |
+
- timesteps (FloatTensor as implemented; see note above)
|
| 64 |
+
- state_ends_mask (BoolTensor)
|
| 65 |
+
"""
|
| 66 |
+
state_ends_mask = []
|
| 67 |
+
input_ids = []
|
| 68 |
+
action_mask = []
|
| 69 |
+
timesteps = []
|
| 70 |
+
entropy_mask = []
|
| 71 |
+
engine_log_probs = []
|
| 72 |
+
for train_chat_turn in chat_history:
|
| 73 |
+
is_state_end = train_chat_turn.is_state_end
|
| 74 |
+
time_step = train_chat_turn.time_step
|
| 75 |
+
is_action = train_chat_turn.role == "assistant"
|
| 76 |
+
|
| 77 |
+
# Remove exploration prompts from training data
|
| 78 |
+
for exploration_prompt in exploration_prompts_to_remove:
|
| 79 |
+
if exploration_prompt in train_chat_turn.content:
|
| 80 |
+
train_chat_turn.content = train_chat_turn.content.replace(
|
| 81 |
+
exploration_prompt, ""
|
| 82 |
+
)
|
| 83 |
+
|
| 84 |
+
chat_turn = {
|
| 85 |
+
"role": train_chat_turn.role,
|
| 86 |
+
"content": train_chat_turn.content,
|
| 87 |
+
}
|
| 88 |
+
if entropy_mask_regex is not None:
|
| 89 |
+
is_entropy_mask_true = (
|
| 90 |
+
regex.search(entropy_mask_regex, train_chat_turn.content) is not None
|
| 91 |
+
)
|
| 92 |
+
else:
|
| 93 |
+
is_entropy_mask_true = True
|
| 94 |
+
if is_action:
|
| 95 |
+
chat_turn_ids = train_chat_turn.out_token_ids
|
| 96 |
+
nb_chat_turns_ids = chat_turn_ids.numel()
|
| 97 |
+
action_mask.append(torch.ones(nb_chat_turns_ids, dtype=torch.bool))
|
| 98 |
+
engine_log_probs.append(train_chat_turn.log_probs)
|
| 99 |
+
else:
|
| 100 |
+
chat_turn_ids = train_chat_turn.chat_template_token_ids
|
| 101 |
+
nb_chat_turns_ids = chat_turn_ids.numel()
|
| 102 |
+
action_mask.append(torch.zeros(nb_chat_turns_ids, dtype=torch.bool))
|
| 103 |
+
engine_log_probs.append(torch.zeros(nb_chat_turns_ids, dtype=torch.float))
|
| 104 |
+
nb_chat_turns_ids = chat_turn_ids.numel()
|
| 105 |
+
state_ends_mask.append(torch.zeros(nb_chat_turns_ids, dtype=torch.bool))
|
| 106 |
+
if is_state_end:
|
| 107 |
+
state_ends_mask[-1][-1] = True # last token is state end
|
| 108 |
+
input_ids.append(chat_turn_ids)
|
| 109 |
+
entropy_mask.append(torch.ones(nb_chat_turns_ids, dtype=torch.bool))
|
| 110 |
+
if not is_entropy_mask_true:
|
| 111 |
+
entropy_mask[-1] = entropy_mask[-1] * False
|
| 112 |
+
timesteps.append(torch.ones(nb_chat_turns_ids) * time_step)
|
| 113 |
+
input_ids = torch.cat(input_ids)
|
| 114 |
+
action_mask = torch.cat(action_mask)
|
| 115 |
+
entropy_mask = torch.cat(entropy_mask)
|
| 116 |
+
timesteps = torch.cat(timesteps)
|
| 117 |
+
timesteps = timesteps.to(torch.long)
|
| 118 |
+
state_ends_mask = torch.cat(state_ends_mask)
|
| 119 |
+
engine_log_probs = torch.cat(engine_log_probs)
|
| 120 |
+
|
| 121 |
+
return (
|
| 122 |
+
input_ids,
|
| 123 |
+
action_mask,
|
| 124 |
+
entropy_mask,
|
| 125 |
+
timesteps,
|
| 126 |
+
state_ends_mask,
|
| 127 |
+
engine_log_probs,
|
| 128 |
+
)
|
src_code_for_reproducibility/training/trainer_ad_align.py
ADDED
|
@@ -0,0 +1,492 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import copy
|
| 2 |
+
import logging
|
| 3 |
+
import sys
|
| 4 |
+
from dataclasses import dataclass
|
| 5 |
+
from typing import Tuple
|
| 6 |
+
|
| 7 |
+
import torch
|
| 8 |
+
from torch.nn.utils.rnn import pad_sequence
|
| 9 |
+
|
| 10 |
+
from mllm.markov_games.rollout_tree import (
|
| 11 |
+
ChatTurn,
|
| 12 |
+
RolloutTreeBranchNode,
|
| 13 |
+
RolloutTreeRootNode,
|
| 14 |
+
)
|
| 15 |
+
from mllm.training.credit_methods import (
|
| 16 |
+
get_advantage_alignment_credits,
|
| 17 |
+
get_discounted_state_visitation_credits,
|
| 18 |
+
)
|
| 19 |
+
from mllm.training.tally_metrics import Tally
|
| 20 |
+
from mllm.training.tally_rollout import RolloutTally, RolloutTallyItem
|
| 21 |
+
from mllm.training.tally_tokenwise import ContextualizedTokenwiseTally
|
| 22 |
+
from mllm.training.tokenize_chats import process_training_chat
|
| 23 |
+
from mllm.training.trainer_common import BaseTrainer
|
| 24 |
+
from mllm.training.training_data_utils import (
|
| 25 |
+
AdvantagePacket,
|
| 26 |
+
TrainingBatch,
|
| 27 |
+
TrainingChatTurn,
|
| 28 |
+
TrajectoryBatch,
|
| 29 |
+
get_main_chat_list_and_rewards,
|
| 30 |
+
get_tokenwise_credits,
|
| 31 |
+
)
|
| 32 |
+
from mllm.utils.resource_context import resource_logger_context
|
| 33 |
+
|
| 34 |
+
logger = logging.getLogger(__name__)
|
| 35 |
+
logger.addHandler(logging.StreamHandler(sys.stdout))
|
| 36 |
+
|
| 37 |
+
RolloutId = int
|
| 38 |
+
AgentId = str
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
@dataclass
|
| 42 |
+
class AdAlignTrainingData:
|
| 43 |
+
agent_id: str
|
| 44 |
+
main_data: TrajectoryBatch
|
| 45 |
+
# list-of-tensors: per rollout advantages with length jT
|
| 46 |
+
main_advantages: list[torch.FloatTensor] | None = None
|
| 47 |
+
# list-of-tensors: per rollout matrix (jT, A)
|
| 48 |
+
alternative_advantages: list[torch.FloatTensor] | None = None
|
| 49 |
+
advantage_alignment_credits: list[torch.FloatTensor] | None = None
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
def get_alternative_chat_histories(
|
| 53 |
+
agent_id: str, root: RolloutTreeRootNode
|
| 54 |
+
) -> list[list[TrainingChatTurn], list[torch.FloatTensor]]:
|
| 55 |
+
"""
|
| 56 |
+
args:
|
| 57 |
+
agent_id: The agent we want to get the chat history for.
|
| 58 |
+
root: The root of the rollout tree.
|
| 59 |
+
returns:
|
| 60 |
+
alternative_chats: list[list[TrainingChatTurn]] (jT*A, jS')
|
| 61 |
+
alternative_rewards: list[torch.FloatTensor] (jT*A, jT')
|
| 62 |
+
"""
|
| 63 |
+
current_node = root.child
|
| 64 |
+
branches = current_node.branches
|
| 65 |
+
pre_branch_chat = []
|
| 66 |
+
pre_branch_rewards = []
|
| 67 |
+
alternative_rewards = []
|
| 68 |
+
alternative_chats = []
|
| 69 |
+
while current_node is not None:
|
| 70 |
+
assert isinstance(
|
| 71 |
+
current_node, RolloutTreeBranchNode
|
| 72 |
+
), "Current node should be a branch node."
|
| 73 |
+
main_node = current_node.main_child
|
| 74 |
+
branches = current_node.branches
|
| 75 |
+
current_node = main_node.child
|
| 76 |
+
|
| 77 |
+
# Get the `A` alternative trajectories
|
| 78 |
+
alternative_nodes = branches[agent_id]
|
| 79 |
+
for alt_node in alternative_nodes:
|
| 80 |
+
post_branch_chat, post_branch_rewards = get_main_chat_list_and_rewards(
|
| 81 |
+
agent_id=agent_id, root=alt_node
|
| 82 |
+
)
|
| 83 |
+
branch_chat = pre_branch_chat + post_branch_chat
|
| 84 |
+
alternative_chats.append(branch_chat)
|
| 85 |
+
alternative_rewards.append(
|
| 86 |
+
torch.cat([torch.tensor(pre_branch_rewards), post_branch_rewards])
|
| 87 |
+
)
|
| 88 |
+
|
| 89 |
+
chat_turns: list[ChatTurn] = main_node.step_log.action_logs[agent_id].chat_turns
|
| 90 |
+
chat_turns: list[TrainingChatTurn] = [
|
| 91 |
+
TrainingChatTurn(time_step=main_node.time_step, **turn.model_dump())
|
| 92 |
+
for turn in chat_turns
|
| 93 |
+
]
|
| 94 |
+
|
| 95 |
+
pre_branch_chat.extend(chat_turns)
|
| 96 |
+
pre_branch_rewards.append(
|
| 97 |
+
main_node.step_log.simulation_step_log.rewards[agent_id]
|
| 98 |
+
)
|
| 99 |
+
|
| 100 |
+
return alternative_chats, alternative_rewards
|
| 101 |
+
|
| 102 |
+
|
| 103 |
+
class TrainerAdAlign(BaseTrainer):
|
| 104 |
+
"""
|
| 105 |
+
Extends the reinforce trainer to support Advantage Alignment.
|
| 106 |
+
"""
|
| 107 |
+
|
| 108 |
+
def __init__(
|
| 109 |
+
self,
|
| 110 |
+
ad_align_beta: float,
|
| 111 |
+
ad_align_gamma: float,
|
| 112 |
+
ad_align_exclude_k_equals_t: bool,
|
| 113 |
+
ad_align_use_sign: bool,
|
| 114 |
+
ad_align_clipping: float,
|
| 115 |
+
ad_align_force_coop_first_step: bool,
|
| 116 |
+
use_old_ad_align: bool,
|
| 117 |
+
use_time_regularization: bool,
|
| 118 |
+
rloo_branch: bool,
|
| 119 |
+
reuse_baseline: bool,
|
| 120 |
+
ad_align_beta_anneal_step: int = -1,
|
| 121 |
+
ad_align_beta_anneal_rate: float = 0.5,
|
| 122 |
+
min_ad_align_beta: float = 0.1,
|
| 123 |
+
mean_normalize_ad_align: bool = False,
|
| 124 |
+
whiten_adalign_advantages: bool = False,
|
| 125 |
+
whiten_adalign_advantages_time_step_wise: bool = False,
|
| 126 |
+
*args,
|
| 127 |
+
**kwargs,
|
| 128 |
+
):
|
| 129 |
+
"""
|
| 130 |
+
Initialize the advantage alignment trainer.
|
| 131 |
+
Args:
|
| 132 |
+
ad_align_beta: Beta parameter for the advantage alignment.
|
| 133 |
+
ad_align_gamma: Gamma parameter for the advantage alignment.
|
| 134 |
+
ad_align_exclude_k_equals_t: Whether to include k = t in the advantage alignment.
|
| 135 |
+
ad_align_use_sign: Whether to use sign in the advantage alignment.
|
| 136 |
+
ad_align_clipping: Clipping value for the advantage alignment.
|
| 137 |
+
ad_align_force_coop_first_step: Whether to force coop on the first step of the advantage alignment.
|
| 138 |
+
"""
|
| 139 |
+
super().__init__(*args, **kwargs)
|
| 140 |
+
self.ad_align_beta = ad_align_beta
|
| 141 |
+
self.ad_align_gamma = ad_align_gamma
|
| 142 |
+
self.ad_align_exclude_k_equals_t = ad_align_exclude_k_equals_t
|
| 143 |
+
self.ad_align_use_sign = ad_align_use_sign
|
| 144 |
+
self.ad_align_clipping = ad_align_clipping
|
| 145 |
+
self.ad_align_force_coop_first_step = ad_align_force_coop_first_step
|
| 146 |
+
self.use_old_ad_align = use_old_ad_align
|
| 147 |
+
self.use_time_regularization = use_time_regularization
|
| 148 |
+
self.rloo_branch = rloo_branch
|
| 149 |
+
self.reuse_baseline = reuse_baseline
|
| 150 |
+
self.ad_align_beta_anneal_step = ad_align_beta_anneal_step
|
| 151 |
+
self.ad_align_beta_anneal_rate = ad_align_beta_anneal_rate
|
| 152 |
+
self.min_ad_align_beta = min_ad_align_beta
|
| 153 |
+
self.past_ad_align_step = -1
|
| 154 |
+
self.mean_normalize_ad_align = mean_normalize_ad_align
|
| 155 |
+
self.whiten_adalign_advantages = whiten_adalign_advantages
|
| 156 |
+
self.whiten_adalign_advantages_time_step_wise = (
|
| 157 |
+
whiten_adalign_advantages_time_step_wise
|
| 158 |
+
)
|
| 159 |
+
self.training_data: dict[AgentId, AdAlignTrainingData] = {}
|
| 160 |
+
self.debug_path_list: list[str] = []
|
| 161 |
+
|
| 162 |
+
def set_agent_trajectory_data(
|
| 163 |
+
self, agent_id: str, roots: list[RolloutTreeRootNode]
|
| 164 |
+
):
|
| 165 |
+
"""
|
| 166 |
+
TOWRITE
|
| 167 |
+
Set the advantage alignment data for the trainer.
|
| 168 |
+
"""
|
| 169 |
+
|
| 170 |
+
B = len(roots) # Number of rollouts
|
| 171 |
+
|
| 172 |
+
# For main rollouts
|
| 173 |
+
batch_rollout_ids = []
|
| 174 |
+
batch_crn_ids = []
|
| 175 |
+
batch_input_ids = []
|
| 176 |
+
batch_action_mask = []
|
| 177 |
+
batch_entropy_mask = []
|
| 178 |
+
batch_timesteps = []
|
| 179 |
+
batch_state_ends_mask = []
|
| 180 |
+
batch_engine_log_probs = []
|
| 181 |
+
batch_rewards = []
|
| 182 |
+
|
| 183 |
+
# For alternative actions rollouts
|
| 184 |
+
batch_branching_time_steps = []
|
| 185 |
+
alternative_batch_input_ids = []
|
| 186 |
+
alternative_batch_action_mask = []
|
| 187 |
+
alternative_batch_entropy_mask = []
|
| 188 |
+
alternative_batch_timesteps = []
|
| 189 |
+
alternative_batch_state_ends_mask = []
|
| 190 |
+
alternative_batch_engine_log_probs = []
|
| 191 |
+
alternative_batch_rewards = []
|
| 192 |
+
jT_list = []
|
| 193 |
+
|
| 194 |
+
try:
|
| 195 |
+
A = len(roots[0].child.branches[agent_id]) # Number of alternative actions
|
| 196 |
+
except:
|
| 197 |
+
A = 0
|
| 198 |
+
|
| 199 |
+
for root in roots:
|
| 200 |
+
rollout_id = root.id
|
| 201 |
+
self.debug_path_list.append(
|
| 202 |
+
"mgid:" + str(rollout_id) + "_agent_id:" + agent_id
|
| 203 |
+
)
|
| 204 |
+
# Get main trajectory
|
| 205 |
+
batch_rollout_ids.append(rollout_id)
|
| 206 |
+
batch_crn_ids.append(root.crn_id)
|
| 207 |
+
main_chat, main_rewards = get_main_chat_list_and_rewards(
|
| 208 |
+
agent_id=agent_id, root=root
|
| 209 |
+
)
|
| 210 |
+
(
|
| 211 |
+
input_ids,
|
| 212 |
+
action_mask,
|
| 213 |
+
entropy_mask,
|
| 214 |
+
timesteps,
|
| 215 |
+
state_ends_mask,
|
| 216 |
+
engine_log_probs,
|
| 217 |
+
) = process_training_chat(
|
| 218 |
+
tokenizer=self.tokenizer,
|
| 219 |
+
chat_history=main_chat,
|
| 220 |
+
entropy_mask_regex=self.entropy_mask_regex,
|
| 221 |
+
exploration_prompts_to_remove=self.exploration_prompts_to_remove,
|
| 222 |
+
)
|
| 223 |
+
batch_input_ids.append(input_ids)
|
| 224 |
+
batch_action_mask.append(action_mask)
|
| 225 |
+
batch_entropy_mask.append(entropy_mask)
|
| 226 |
+
batch_timesteps.append(timesteps)
|
| 227 |
+
batch_state_ends_mask.append(state_ends_mask)
|
| 228 |
+
batch_engine_log_probs.append(engine_log_probs)
|
| 229 |
+
batch_rewards.append(main_rewards)
|
| 230 |
+
jT = main_rewards.numel() # TODO: better than this
|
| 231 |
+
jT_list.append(jT)
|
| 232 |
+
if A > 0:
|
| 233 |
+
# We get the branching time steps for each of the `jT` time steps in the main trajectory.
|
| 234 |
+
branching_time_steps = [bt for item in range(jT) for bt in A * [item]]
|
| 235 |
+
batch_branching_time_steps.extend(branching_time_steps)
|
| 236 |
+
|
| 237 |
+
# Get all of the (jT*A) alternative trajectories in the tree
|
| 238 |
+
# (jT is the number of time steps in the main trajectory, A is the number of alternative actions)
|
| 239 |
+
alternative_chats, alternative_rewards = get_alternative_chat_histories(
|
| 240 |
+
agent_id=agent_id, root=root
|
| 241 |
+
)
|
| 242 |
+
assert (
|
| 243 |
+
len(alternative_chats) == A * jT
|
| 244 |
+
), "Incorrect number of alternative trajectories."
|
| 245 |
+
|
| 246 |
+
for chat, rewards in zip(alternative_chats, alternative_rewards):
|
| 247 |
+
(
|
| 248 |
+
input_ids,
|
| 249 |
+
action_mask,
|
| 250 |
+
entropy_mask,
|
| 251 |
+
timesteps,
|
| 252 |
+
state_ends_mask,
|
| 253 |
+
engine_log_probs,
|
| 254 |
+
) = process_training_chat(
|
| 255 |
+
tokenizer=self.tokenizer,
|
| 256 |
+
chat_history=chat,
|
| 257 |
+
entropy_mask_regex=self.entropy_mask_regex,
|
| 258 |
+
exploration_prompts_to_remove=self.exploration_prompts_to_remove,
|
| 259 |
+
)
|
| 260 |
+
alternative_batch_input_ids.append(input_ids)
|
| 261 |
+
alternative_batch_action_mask.append(action_mask)
|
| 262 |
+
alternative_batch_entropy_mask.append(entropy_mask)
|
| 263 |
+
alternative_batch_timesteps.append(timesteps)
|
| 264 |
+
alternative_batch_state_ends_mask.append(state_ends_mask)
|
| 265 |
+
alternative_batch_engine_log_probs.append(engine_log_probs)
|
| 266 |
+
alternative_batch_rewards.append(rewards)
|
| 267 |
+
|
| 268 |
+
jT_list = torch.Tensor(jT_list)
|
| 269 |
+
|
| 270 |
+
# Assert that number of alternative actions is constant
|
| 271 |
+
# assert len(set(nb_alternative_actions)) == 1, "Number of alternative actions must be constant"
|
| 272 |
+
# A = nb_alternative_actions[0]
|
| 273 |
+
|
| 274 |
+
trajectory_batch = TrajectoryBatch(
|
| 275 |
+
rollout_ids=torch.tensor(batch_rollout_ids, dtype=torch.int32), # (B,)
|
| 276 |
+
crn_ids=torch.tensor(batch_crn_ids, dtype=torch.int32),
|
| 277 |
+
agent_ids=[agent_id] * len(batch_rollout_ids),
|
| 278 |
+
batch_input_ids=batch_input_ids,
|
| 279 |
+
batch_action_mask=batch_action_mask,
|
| 280 |
+
batch_entropy_mask=batch_entropy_mask,
|
| 281 |
+
batch_timesteps=batch_timesteps,
|
| 282 |
+
batch_state_ends_mask=batch_state_ends_mask,
|
| 283 |
+
batch_engine_log_probs=batch_engine_log_probs,
|
| 284 |
+
batch_rewards=batch_rewards,
|
| 285 |
+
)
|
| 286 |
+
# Get Advantages & Train Critic
|
| 287 |
+
with resource_logger_context(
|
| 288 |
+
logger, "Get advantages with critic gradient accumulation"
|
| 289 |
+
):
|
| 290 |
+
self.batch_advantages: torch.FloatTensor = (
|
| 291 |
+
self.get_advantages_with_critic_gradient_accumulation(trajectory_batch)
|
| 292 |
+
) # (B, jT)
|
| 293 |
+
|
| 294 |
+
if A > 0:
|
| 295 |
+
# Here, `A` is the number of alternative actions / trajectories taken at each time step.
|
| 296 |
+
# For each of the `B` rollout perspectives, at each of its jT (`j` is for jagged, since each main rollout may be of a different length) steps, we take A alternate trajectories (from different actions).
|
| 297 |
+
# Therefore, we have ∑jT * A trajectories to process. If each of the main trajectories have T steps, we will have `B*T*A` to process.
|
| 298 |
+
with resource_logger_context(logger, "Create alternative trajectory batch"):
|
| 299 |
+
sum_jT = int(torch.sum(jT_list).item())
|
| 300 |
+
jT_list = (
|
| 301 |
+
jT_list.int().tolist()
|
| 302 |
+
) # (jT,) # (we only want the advantages where we branched out)
|
| 303 |
+
alternative_trajectory_batch = TrajectoryBatch(
|
| 304 |
+
rollout_ids=torch.zeros(A * sum_jT, dtype=torch.int32),
|
| 305 |
+
crn_ids=torch.zeros(A * sum_jT, dtype=torch.int32),
|
| 306 |
+
agent_ids=[agent_id] * (A * sum_jT),
|
| 307 |
+
batch_input_ids=alternative_batch_input_ids,
|
| 308 |
+
batch_action_mask=alternative_batch_action_mask,
|
| 309 |
+
batch_entropy_mask=alternative_batch_entropy_mask,
|
| 310 |
+
batch_timesteps=alternative_batch_timesteps,
|
| 311 |
+
batch_state_ends_mask=alternative_batch_state_ends_mask,
|
| 312 |
+
batch_engine_log_probs=alternative_batch_engine_log_probs,
|
| 313 |
+
batch_rewards=alternative_batch_rewards,
|
| 314 |
+
)
|
| 315 |
+
|
| 316 |
+
# Get alternative advantages
|
| 317 |
+
# BAAs stands for batch alternative advantages
|
| 318 |
+
# (torch nested tensors have very little api support, so we have to do some odd manual work here)
|
| 319 |
+
with resource_logger_context(
|
| 320 |
+
logger, "Compute alternative advantage estimates"
|
| 321 |
+
):
|
| 322 |
+
BAAs_list = self.get_advantages_with_critic_gradient_accumulation(
|
| 323 |
+
alternative_trajectory_batch
|
| 324 |
+
) # list length (∑jT * A), each (jT',)
|
| 325 |
+
# Pad alternative advantages to (∑jT*A, P)
|
| 326 |
+
|
| 327 |
+
BAAs_padded = pad_sequence(
|
| 328 |
+
BAAs_list, batch_first=True, padding_value=0.0
|
| 329 |
+
)
|
| 330 |
+
branch_idx = torch.tensor(
|
| 331 |
+
batch_branching_time_steps,
|
| 332 |
+
device=BAAs_padded.device,
|
| 333 |
+
dtype=torch.long,
|
| 334 |
+
)
|
| 335 |
+
gathered = BAAs_padded.gather(
|
| 336 |
+
dim=1, index=branch_idx.unsqueeze(1)
|
| 337 |
+
).squeeze(1)
|
| 338 |
+
# Reshape and split per rollout, then transpose to (jT_i, A)
|
| 339 |
+
gathered = gathered.view(A, sum_jT) # (A, ∑jT)
|
| 340 |
+
blocks = list(
|
| 341 |
+
torch.split(gathered, jT_list, dim=1)
|
| 342 |
+
) # len B, shapes (A, jT_i)
|
| 343 |
+
BAAs = [
|
| 344 |
+
blk.transpose(0, 1).contiguous() for blk in blocks
|
| 345 |
+
] # list of (jT_i, A)
|
| 346 |
+
if self.ad_align_beta_anneal_step > 0:
|
| 347 |
+
max_rollout_id = torch.max(trajectory_batch.rollout_ids) + 1
|
| 348 |
+
if (
|
| 349 |
+
max_rollout_id % self.ad_align_beta_anneal_step == 0
|
| 350 |
+
and self.past_ad_align_step != max_rollout_id
|
| 351 |
+
):
|
| 352 |
+
self.ad_align_beta = max(
|
| 353 |
+
self.ad_align_beta * self.ad_align_beta_anneal_rate,
|
| 354 |
+
self.min_ad_align_beta,
|
| 355 |
+
)
|
| 356 |
+
logger.info(f"Annealing ad_align_beta to {self.ad_align_beta}")
|
| 357 |
+
self.past_ad_align_step = max_rollout_id
|
| 358 |
+
self.training_data[agent_id] = AdAlignTrainingData(
|
| 359 |
+
agent_id=agent_id,
|
| 360 |
+
main_data=trajectory_batch,
|
| 361 |
+
main_advantages=self.batch_advantages,
|
| 362 |
+
alternative_advantages=BAAs if A > 0 else None,
|
| 363 |
+
)
|
| 364 |
+
|
| 365 |
+
def share_advantage_data(self) -> list[AdvantagePacket]:
|
| 366 |
+
"""
|
| 367 |
+
Share the advantage alignment data with other agents.
|
| 368 |
+
Returns:
|
| 369 |
+
AdvantagePacket: The advantage packet containing the agent's advantages.
|
| 370 |
+
"""
|
| 371 |
+
logger.info(f"Sharing advantage alignment data.")
|
| 372 |
+
advantage_packets = []
|
| 373 |
+
for _, agent_data in self.training_data.items():
|
| 374 |
+
advantage_packets.append(
|
| 375 |
+
AdvantagePacket(
|
| 376 |
+
agent_id=agent_data.agent_id,
|
| 377 |
+
rollout_ids=agent_data.main_data.rollout_ids,
|
| 378 |
+
main_advantages=agent_data.main_advantages,
|
| 379 |
+
)
|
| 380 |
+
)
|
| 381 |
+
return advantage_packets
|
| 382 |
+
|
| 383 |
+
def receive_advantage_data(self, advantage_packets: list[AdvantagePacket]):
|
| 384 |
+
"""
|
| 385 |
+
Receive advantage packets from other players.
|
| 386 |
+
These contain the advantages of the other players' rollouts estimated by them.
|
| 387 |
+
"""
|
| 388 |
+
logger.info(f"Receiving advantage packets.")
|
| 389 |
+
|
| 390 |
+
assert (
|
| 391 |
+
len(advantage_packets) > 0
|
| 392 |
+
), "At least one advantage packet must be provided."
|
| 393 |
+
|
| 394 |
+
for agent_id, agent_data in self.training_data.items():
|
| 395 |
+
coagent_advantage_packets = [
|
| 396 |
+
packet for packet in advantage_packets if packet.agent_id != agent_id
|
| 397 |
+
]
|
| 398 |
+
agent_rollout_ids = agent_data.main_data.rollout_ids
|
| 399 |
+
agent_advantages = agent_data.main_advantages
|
| 400 |
+
co_agent_advantages = []
|
| 401 |
+
for rollout_id in agent_rollout_ids:
|
| 402 |
+
for co_agent_packet in coagent_advantage_packets:
|
| 403 |
+
if rollout_id in co_agent_packet.rollout_ids:
|
| 404 |
+
index = torch.where(rollout_id == co_agent_packet.rollout_ids)[
|
| 405 |
+
0
|
| 406 |
+
].item()
|
| 407 |
+
co_agent_advantages.append(
|
| 408 |
+
co_agent_packet.main_advantages[index]
|
| 409 |
+
)
|
| 410 |
+
# assumes that its two player game, with one co-agent
|
| 411 |
+
break
|
| 412 |
+
assert len(co_agent_advantages) == len(agent_advantages)
|
| 413 |
+
B = len(agent_advantages)
|
| 414 |
+
assert all(
|
| 415 |
+
a.shape[0] == b.shape[0]
|
| 416 |
+
for a, b in zip(co_agent_advantages, agent_advantages)
|
| 417 |
+
), "Number of advantages must match for advantage alignment."
|
| 418 |
+
|
| 419 |
+
# Get padded tensors (advantage alignment is invariant to padding)
|
| 420 |
+
lengths = torch.tensor(
|
| 421 |
+
[len(t) for t in agent_advantages],
|
| 422 |
+
device=self.device,
|
| 423 |
+
dtype=torch.long,
|
| 424 |
+
)
|
| 425 |
+
padded_main_advantages = pad_sequence(
|
| 426 |
+
agent_advantages, batch_first=True, padding_value=0.0
|
| 427 |
+
)
|
| 428 |
+
if agent_data.alternative_advantages:
|
| 429 |
+
padded_alternative_advantages = pad_sequence(
|
| 430 |
+
agent_data.alternative_advantages,
|
| 431 |
+
batch_first=True,
|
| 432 |
+
padding_value=0.0,
|
| 433 |
+
) # (B, P, A)
|
| 434 |
+
else:
|
| 435 |
+
padded_alternative_advantages = None
|
| 436 |
+
padded_co_agent_advantages = pad_sequence(
|
| 437 |
+
co_agent_advantages, batch_first=True, padding_value=0.0
|
| 438 |
+
)
|
| 439 |
+
|
| 440 |
+
# Create training batch data
|
| 441 |
+
credits, sub_tensors = get_advantage_alignment_credits(
|
| 442 |
+
a1=padded_main_advantages,
|
| 443 |
+
a1_alternative=padded_alternative_advantages,
|
| 444 |
+
a2=padded_co_agent_advantages,
|
| 445 |
+
beta=self.ad_align_beta,
|
| 446 |
+
gamma=self.ad_align_gamma,
|
| 447 |
+
exclude_k_equals_t=self.ad_align_exclude_k_equals_t,
|
| 448 |
+
use_sign=self.ad_align_use_sign,
|
| 449 |
+
clipping=self.ad_align_clipping,
|
| 450 |
+
force_coop_first_step=self.ad_align_force_coop_first_step,
|
| 451 |
+
use_old_ad_align=self.use_old_ad_align,
|
| 452 |
+
use_time_regularization=self.use_time_regularization,
|
| 453 |
+
rloo_branch=self.rloo_branch,
|
| 454 |
+
reuse_baseline=self.reuse_baseline,
|
| 455 |
+
mean_normalize_ad_align=self.mean_normalize_ad_align,
|
| 456 |
+
whiten_adalign_advantages=self.whiten_adalign_advantages,
|
| 457 |
+
whiten_adalign_advantages_time_step_wise=self.whiten_adalign_advantages_time_step_wise,
|
| 458 |
+
)
|
| 459 |
+
for key, value in sub_tensors.items():
|
| 460 |
+
self.rollout_tally.add_metric(
|
| 461 |
+
path=[key],
|
| 462 |
+
rollout_tally_item=RolloutTallyItem(
|
| 463 |
+
crn_ids=agent_data.main_data.crn_ids,
|
| 464 |
+
rollout_ids=agent_data.main_data.rollout_ids,
|
| 465 |
+
agent_ids=agent_data.main_data.agent_ids,
|
| 466 |
+
metric_matrix=value,
|
| 467 |
+
),
|
| 468 |
+
)
|
| 469 |
+
|
| 470 |
+
if not self.skip_discounted_state_visitation:
|
| 471 |
+
credits = get_discounted_state_visitation_credits(
|
| 472 |
+
credits,
|
| 473 |
+
self.discount_factor,
|
| 474 |
+
)
|
| 475 |
+
self.rollout_tally.add_metric(
|
| 476 |
+
path=["discounted_state_visitation_credits"],
|
| 477 |
+
rollout_tally_item=RolloutTallyItem(
|
| 478 |
+
crn_ids=agent_data.main_data.crn_ids,
|
| 479 |
+
rollout_ids=agent_data.main_data.rollout_ids,
|
| 480 |
+
agent_ids=agent_data.main_data.agent_ids,
|
| 481 |
+
metric_matrix=sub_tensors[
|
| 482 |
+
"discounted_state_visitation_credits"
|
| 483 |
+
],
|
| 484 |
+
),
|
| 485 |
+
)
|
| 486 |
+
|
| 487 |
+
# Slice back to jagged
|
| 488 |
+
advantage_alignment_credits = [credits[i, : lengths[i]] for i in range(B)]
|
| 489 |
+
# Replace stored training data for this agent by the concrete trajectory batch
|
| 490 |
+
# and attach the computed credits for policy gradient.
|
| 491 |
+
self.training_data[agent_id] = agent_data.main_data
|
| 492 |
+
self.training_data[agent_id].batch_credits = advantage_alignment_credits
|
src_code_for_reproducibility/training/trainer_common.py
ADDED
|
@@ -0,0 +1,1054 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
TODO: Add coefficients for losses (depend on total number of tokens or batch)
|
| 3 |
+
TODO: adapt reinforce step for torch.compile
|
| 4 |
+
TODO: add lr schedulers support
|
| 5 |
+
"""
|
| 6 |
+
import logging
|
| 7 |
+
import os
|
| 8 |
+
import pickle
|
| 9 |
+
import sys
|
| 10 |
+
from abc import ABC, abstractmethod
|
| 11 |
+
from typing import Callable, Literal, Union
|
| 12 |
+
|
| 13 |
+
import numpy as np
|
| 14 |
+
import torch
|
| 15 |
+
import torch.nn.functional as F
|
| 16 |
+
from accelerate import Accelerator
|
| 17 |
+
from pandas._libs.tslibs.offsets import CBMonthBegin
|
| 18 |
+
from peft import LoraConfig
|
| 19 |
+
from torch.nn.utils.rnn import pad_sequence
|
| 20 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer
|
| 21 |
+
|
| 22 |
+
from mllm.markov_games.rollout_tree import *
|
| 23 |
+
from mllm.markov_games.rollout_tree import RolloutTreeRootNode
|
| 24 |
+
from mllm.training.annealing_methods import sigmoid_annealing
|
| 25 |
+
from mllm.training.credit_methods import (
|
| 26 |
+
get_discounted_returns,
|
| 27 |
+
get_generalized_advantage_estimates,
|
| 28 |
+
get_rloo_credits,
|
| 29 |
+
whiten_advantages,
|
| 30 |
+
whiten_advantages_time_step_wise,
|
| 31 |
+
)
|
| 32 |
+
from mllm.training.tally_metrics import Tally
|
| 33 |
+
from mllm.training.tally_rollout import RolloutTally, RolloutTallyItem
|
| 34 |
+
from mllm.training.tally_tokenwise import ContextualizedTokenwiseTally
|
| 35 |
+
from mllm.training.tokenize_chats import *
|
| 36 |
+
from mllm.training.tokenize_chats import process_training_chat
|
| 37 |
+
from mllm.training.training_data_utils import *
|
| 38 |
+
from mllm.training.training_data_utils import (
|
| 39 |
+
TrainingBatch,
|
| 40 |
+
TrajectoryBatch,
|
| 41 |
+
get_tokenwise_credits,
|
| 42 |
+
)
|
| 43 |
+
from mllm.utils.resource_context import resource_logger_context
|
| 44 |
+
|
| 45 |
+
logger = logging.getLogger(__name__)
|
| 46 |
+
logger.addHandler(logging.StreamHandler(sys.stdout))
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
@dataclass
|
| 50 |
+
class TrainerAnnealingState:
|
| 51 |
+
annealing_step_counter: int = 0
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
class BaseTrainer(ABC):
|
| 55 |
+
"""
|
| 56 |
+
Trainer
|
| 57 |
+
"""
|
| 58 |
+
|
| 59 |
+
def __init__(
|
| 60 |
+
self,
|
| 61 |
+
policy: AutoModelForCausalLM,
|
| 62 |
+
policy_optimizer: torch.optim.Optimizer,
|
| 63 |
+
critic: Union[AutoModelForCausalLM, None],
|
| 64 |
+
critic_optimizer: Union[torch.optim.Optimizer, None],
|
| 65 |
+
tokenizer: AutoTokenizer,
|
| 66 |
+
lr_scheduler: torch.optim.lr_scheduler.LRScheduler,
|
| 67 |
+
critic_lr_scheduler: Union[torch.optim.lr_scheduler.LRScheduler, None],
|
| 68 |
+
######################################################################
|
| 69 |
+
entropy_coeff: float,
|
| 70 |
+
entropy_topk: int,
|
| 71 |
+
entropy_mask_regex: Union[str, None],
|
| 72 |
+
kl_coeff: float,
|
| 73 |
+
gradient_clipping: Union[float, None],
|
| 74 |
+
restrict_tokens: Union[list[str], None],
|
| 75 |
+
mini_batch_size: int,
|
| 76 |
+
use_gradient_checkpointing: bool,
|
| 77 |
+
temperature: float,
|
| 78 |
+
device: str,
|
| 79 |
+
whiten_advantages: bool,
|
| 80 |
+
whiten_advantages_time_step_wise: bool,
|
| 81 |
+
use_gae: bool,
|
| 82 |
+
use_gae_lambda_annealing: bool,
|
| 83 |
+
gae_lambda_annealing_limit: float,
|
| 84 |
+
gae_lambda_annealing_method: Literal["sigmoid_annealing"],
|
| 85 |
+
gae_lambda_annealing_method_params: dict,
|
| 86 |
+
pg_loss_normalization: Literal["batch", "nb_tokens"],
|
| 87 |
+
use_rloo: bool,
|
| 88 |
+
skip_discounted_state_visitation: bool,
|
| 89 |
+
discount_factor: float,
|
| 90 |
+
enable_tokenwise_logging: bool,
|
| 91 |
+
save_path: str,
|
| 92 |
+
reward_normalizing_constant: float = 1.0,
|
| 93 |
+
critic_loss_type: Literal["mse", "huber"] = "huber",
|
| 94 |
+
exploration_prompts_to_remove: list[str] = [],
|
| 95 |
+
filter_higher_refprob_tokens_kl: bool = False,
|
| 96 |
+
truncated_importance_sampling_ratio_cap: float = 0.0,
|
| 97 |
+
importance_sampling_strategy: Literal[
|
| 98 |
+
"per_token", "per_sequence"
|
| 99 |
+
] = "per_token",
|
| 100 |
+
):
|
| 101 |
+
"""
|
| 102 |
+
Initialize the REINFORCE trainer with reward shaping for multi-agent or single-agent training.
|
| 103 |
+
|
| 104 |
+
Args:
|
| 105 |
+
model (AutoModelForCausalLM): The main policy model.
|
| 106 |
+
tokenizer (AutoTokenizer): Tokenizer for the model.
|
| 107 |
+
optimizer (torch.optim.Optimizer): Optimizer for the policy model.
|
| 108 |
+
lr_scheduler (torch.optim.lr_scheduler.LRScheduler): Learning rate scheduler for the policy model.
|
| 109 |
+
critic (AutoModelForCausalLM or None): Critic model for value estimation (optional).
|
| 110 |
+
critic_optimizer (torch.optim.Optimizer or None): Optimizer for the critic model (optional).
|
| 111 |
+
critic_lr_scheduler (torch.optim.lr_scheduler.LRScheduler or None): LR scheduler for the critic (optional).
|
| 112 |
+
config (RtConfig): Configuration object for training.
|
| 113 |
+
"""
|
| 114 |
+
self.tokenizer = tokenizer
|
| 115 |
+
# self.tokenizer.padding_side = "left" # needed for flash attention
|
| 116 |
+
if self.tokenizer.pad_token_id is None:
|
| 117 |
+
self.tokenizer.pad_token_id = self.tokenizer.eos_token_id
|
| 118 |
+
self.lr_scheduler = lr_scheduler
|
| 119 |
+
self.accelerator = Accelerator()
|
| 120 |
+
(
|
| 121 |
+
self.policy,
|
| 122 |
+
self.policy_optimizer,
|
| 123 |
+
self.critic,
|
| 124 |
+
self.critic_optimizer,
|
| 125 |
+
) = self.accelerator.prepare(policy, policy_optimizer, critic, critic_optimizer)
|
| 126 |
+
|
| 127 |
+
self.critic_lr_scheduler = critic_lr_scheduler
|
| 128 |
+
self.tally = Tally()
|
| 129 |
+
|
| 130 |
+
if use_gradient_checkpointing == True:
|
| 131 |
+
self.policy.gradient_checkpointing_enable(dict(use_reentrant=False))
|
| 132 |
+
if critic is not None:
|
| 133 |
+
self.critic.gradient_checkpointing_enable(dict(use_reentrant=False))
|
| 134 |
+
|
| 135 |
+
self.save_path = save_path
|
| 136 |
+
|
| 137 |
+
# Load trainer state if it exists
|
| 138 |
+
self.trainer_annealing_state_path = os.path.join(
|
| 139 |
+
self.save_path, "trainer_annealing_state.pkl"
|
| 140 |
+
)
|
| 141 |
+
if os.path.exists(self.trainer_annealing_state_path):
|
| 142 |
+
logger.info(
|
| 143 |
+
f"Loading trainer state from {self.trainer_annealing_state_path}"
|
| 144 |
+
)
|
| 145 |
+
self.trainer_annealing_state = pickle.load(
|
| 146 |
+
open(self.trainer_annealing_state_path, "rb")
|
| 147 |
+
)
|
| 148 |
+
else:
|
| 149 |
+
self.trainer_annealing_state = TrainerAnnealingState()
|
| 150 |
+
|
| 151 |
+
# Load policy optimizer state if it exists
|
| 152 |
+
self.policy_optimizer_path = os.path.join(
|
| 153 |
+
self.save_path, "policy_optimizer_state.pt"
|
| 154 |
+
)
|
| 155 |
+
if os.path.exists(self.policy_optimizer_path):
|
| 156 |
+
logger.info(
|
| 157 |
+
f"Loading policy optimizer state from {self.policy_optimizer_path}"
|
| 158 |
+
)
|
| 159 |
+
self.policy_optimizer.load_state_dict(
|
| 160 |
+
torch.load(self.policy_optimizer_path)
|
| 161 |
+
)
|
| 162 |
+
|
| 163 |
+
# Load critic optimizer state if it exists
|
| 164 |
+
self.critic_optimizer_path = os.path.join(
|
| 165 |
+
self.save_path, "critic_optimizer_state.pt"
|
| 166 |
+
)
|
| 167 |
+
if (
|
| 168 |
+
os.path.exists(self.critic_optimizer_path)
|
| 169 |
+
and self.critic_optimizer is not None
|
| 170 |
+
):
|
| 171 |
+
logger.info(
|
| 172 |
+
f"Loading critic optimizer state from {self.critic_optimizer_path}"
|
| 173 |
+
)
|
| 174 |
+
self.critic_optimizer.load_state_dict(
|
| 175 |
+
torch.load(self.critic_optimizer_path)
|
| 176 |
+
)
|
| 177 |
+
self.device = self.accelerator.device
|
| 178 |
+
self.entropy_coeff = entropy_coeff
|
| 179 |
+
self.entropy_topk = entropy_topk
|
| 180 |
+
self.entropy_mask_regex = entropy_mask_regex
|
| 181 |
+
self.kl_coeff = kl_coeff
|
| 182 |
+
self.gradient_clipping = gradient_clipping
|
| 183 |
+
self.restrict_tokens = restrict_tokens
|
| 184 |
+
self.mini_batch_size = mini_batch_size
|
| 185 |
+
self.use_gradient_checkpointing = use_gradient_checkpointing
|
| 186 |
+
self.temperature = temperature
|
| 187 |
+
self.use_gae = use_gae
|
| 188 |
+
self.whiten_advantages = whiten_advantages
|
| 189 |
+
self.whiten_advantages_time_step_wise = whiten_advantages_time_step_wise
|
| 190 |
+
self.use_rloo = use_rloo
|
| 191 |
+
self.skip_discounted_state_visitation = skip_discounted_state_visitation
|
| 192 |
+
self.use_gae_lambda_annealing = use_gae_lambda_annealing
|
| 193 |
+
self.gae_lambda_annealing_limit = gae_lambda_annealing_limit
|
| 194 |
+
if use_gae_lambda_annealing:
|
| 195 |
+
self.gae_lambda_annealing_method: Callable[
|
| 196 |
+
[int], float
|
| 197 |
+
] = lambda step: eval(gae_lambda_annealing_method)(
|
| 198 |
+
step=step, **gae_lambda_annealing_method_params
|
| 199 |
+
)
|
| 200 |
+
self.discount_factor = discount_factor
|
| 201 |
+
self.enable_tokenwise_logging = enable_tokenwise_logging
|
| 202 |
+
self.reward_normalizing_constant = reward_normalizing_constant
|
| 203 |
+
self.pg_loss_normalization = pg_loss_normalization
|
| 204 |
+
self.critic_loss_type = critic_loss_type
|
| 205 |
+
self.exploration_prompts_to_remove = exploration_prompts_to_remove
|
| 206 |
+
# Common containers used by all trainers
|
| 207 |
+
self.training_data: dict = {}
|
| 208 |
+
self.debug_path_list: list[str] = []
|
| 209 |
+
self.policy_gradient_data = None
|
| 210 |
+
self.tally = Tally()
|
| 211 |
+
self.rollout_tally = RolloutTally()
|
| 212 |
+
self.tokenwise_tally: Union[ContextualizedTokenwiseTally, None] = None
|
| 213 |
+
self.filter_higher_refprob_tokens_kl = filter_higher_refprob_tokens_kl
|
| 214 |
+
self.truncated_importance_sampling_ratio_cap = (
|
| 215 |
+
truncated_importance_sampling_ratio_cap
|
| 216 |
+
)
|
| 217 |
+
self.importance_sampling_strategy = importance_sampling_strategy
|
| 218 |
+
|
| 219 |
+
def mask_non_restricted_token_logits(self, logits: torch.Tensor) -> torch.Tensor:
|
| 220 |
+
"""
|
| 221 |
+
Masks logits so that only allowed tokens (as specified in config.restrict_tokens)
|
| 222 |
+
and the EOS token are active.
|
| 223 |
+
All other logits are set to -inf, effectively removing them from the softmax.
|
| 224 |
+
|
| 225 |
+
Args:
|
| 226 |
+
logits (torch.Tensor): The logits tensor of shape (B, S, V).
|
| 227 |
+
|
| 228 |
+
Returns:
|
| 229 |
+
torch.Tensor: The masked logits tensor.
|
| 230 |
+
"""
|
| 231 |
+
# TODO: verify. Not sure what we do here is differentiable
|
| 232 |
+
# also, we recompute for nothing
|
| 233 |
+
|
| 234 |
+
if self.restrict_tokens is not None:
|
| 235 |
+
allowed_token_ids = []
|
| 236 |
+
for token in self.restrict_tokens:
|
| 237 |
+
token_ids = self.tokenizer(token, add_special_tokens=False)["input_ids"]
|
| 238 |
+
allowed_token_ids.append(token_ids[0])
|
| 239 |
+
allowed_token_ids.append(
|
| 240 |
+
self.tokenizer.eos_token_id
|
| 241 |
+
) # This token should always be active
|
| 242 |
+
allowed_token_ids = torch.tensor(allowed_token_ids, device=logits.device)
|
| 243 |
+
# Mask log_probs and probs to only allowed tokens
|
| 244 |
+
mask = torch.zeros_like(logits).bool() # (B, S, V)
|
| 245 |
+
mask[..., allowed_token_ids] = True
|
| 246 |
+
logits = torch.where(
|
| 247 |
+
mask,
|
| 248 |
+
logits,
|
| 249 |
+
torch.tensor(-float("inf"), device=logits.device),
|
| 250 |
+
)
|
| 251 |
+
|
| 252 |
+
return logits
|
| 253 |
+
|
| 254 |
+
# def get_gradient_magnitude(self, loss_term: torch.Tensor) -> float:
|
| 255 |
+
# """
|
| 256 |
+
# Computes the L2 norm of the gradients of the given loss term with respect to the model parameters.
|
| 257 |
+
|
| 258 |
+
# Args:
|
| 259 |
+
# loss_term (torch.Tensor): The loss tensor to compute gradients for.
|
| 260 |
+
|
| 261 |
+
# Returns:
|
| 262 |
+
# float: The L2 norm of the gradients, or 0.0 if no gradients are present.
|
| 263 |
+
# """
|
| 264 |
+
# with torch.no_grad():
|
| 265 |
+
# grads = torch.autograd.grad(
|
| 266 |
+
# loss_term,
|
| 267 |
+
# [p for p in self.policy.parameters() if p.requires_grad],
|
| 268 |
+
# retain_graph=True,
|
| 269 |
+
# allow_unused=True,
|
| 270 |
+
# )
|
| 271 |
+
# grads = [g for g in grads if g is not None]
|
| 272 |
+
# if not grads:
|
| 273 |
+
# return torch.tensor(0.0, device=loss_term.device)
|
| 274 |
+
# return torch.norm(torch.stack([g.norm(2) for g in grads])).item()
|
| 275 |
+
|
| 276 |
+
def apply_reinforce_step(
|
| 277 |
+
self,
|
| 278 |
+
training_batch: TrainingBatch,
|
| 279 |
+
) -> None:
|
| 280 |
+
"""
|
| 281 |
+
Applies a single REINFORCE policy gradient step using the provided batch of rollouts.
|
| 282 |
+
Handles batching, loss computation (including entropy and KL regularization), gradient accumulation, and optimizer step.
|
| 283 |
+
Optionally logs various metrics and statistics.
|
| 284 |
+
|
| 285 |
+
Args:
|
| 286 |
+
paths (list[str]): List of game complete file paths for each rollout.
|
| 287 |
+
contexts (list[torch.Tensor]): List of context tensors for each rollout.
|
| 288 |
+
credits (list[torch.Tensor]): List of credit tensors (rewards/advantages) for each rollout.
|
| 289 |
+
action_masks (list[torch.Tensor]): List of action mask tensors for each rollout.
|
| 290 |
+
"""
|
| 291 |
+
with resource_logger_context(logger, "Apply reinforce step"):
|
| 292 |
+
self.policy.train()
|
| 293 |
+
mb_size = self.mini_batch_size
|
| 294 |
+
nb_rollouts = len(training_batch)
|
| 295 |
+
|
| 296 |
+
# Initialize running mean logs
|
| 297 |
+
running_mean_logs = {
|
| 298 |
+
"rl_objective": 0.0,
|
| 299 |
+
"policy_gradient_loss": 0.0,
|
| 300 |
+
"policy_gradient_norm": 0.0,
|
| 301 |
+
"log_probs": 0.0,
|
| 302 |
+
"credits": 0.0,
|
| 303 |
+
"entropy": 0.0,
|
| 304 |
+
"engine_log_probs_diff_clampfrac": 0.0,
|
| 305 |
+
"tis_imp_ratio": 0.0,
|
| 306 |
+
"ref_log_probs_diff_clampfrac": 0.0,
|
| 307 |
+
"higher_refprob_frac": 0.0,
|
| 308 |
+
"tis_imp_ratio_clampfrac": 0.0,
|
| 309 |
+
}
|
| 310 |
+
if self.entropy_coeff != 0.0:
|
| 311 |
+
running_mean_logs["entropy"] = 0.0
|
| 312 |
+
if self.kl_coeff != 0.0:
|
| 313 |
+
running_mean_logs["kl_divergence"] = 0.0
|
| 314 |
+
|
| 315 |
+
# Get total number of tokens generated
|
| 316 |
+
total_tokens_generated = 0
|
| 317 |
+
for att_mask in training_batch.batch_action_mask:
|
| 318 |
+
total_tokens_generated += att_mask.sum()
|
| 319 |
+
|
| 320 |
+
# Obtain loss normalization
|
| 321 |
+
if self.pg_loss_normalization == "nb_tokens":
|
| 322 |
+
normalization_factor = total_tokens_generated
|
| 323 |
+
elif self.pg_loss_normalization == "batch":
|
| 324 |
+
normalization_factor = np.ceil(nb_rollouts / mb_size).astype(int)
|
| 325 |
+
else:
|
| 326 |
+
raise ValueError(
|
| 327 |
+
f"Invalid pg_loss_normalization: {self.pg_loss_normalization}"
|
| 328 |
+
)
|
| 329 |
+
|
| 330 |
+
# Gradient accumulation for each mini-batch
|
| 331 |
+
for mb in range(0, nb_rollouts, mb_size):
|
| 332 |
+
logger.info(f"Processing mini-batch {mb} of {nb_rollouts}")
|
| 333 |
+
loss = 0.0
|
| 334 |
+
training_mb = training_batch[mb : mb + mb_size]
|
| 335 |
+
training_mb = training_mb.get_padded_tensors()
|
| 336 |
+
training_mb.to(self.device)
|
| 337 |
+
(
|
| 338 |
+
tokens_mb,
|
| 339 |
+
action_mask_mb,
|
| 340 |
+
entropy_mask_mb,
|
| 341 |
+
credits_mb,
|
| 342 |
+
engine_log_probs_mb,
|
| 343 |
+
timesteps_mb,
|
| 344 |
+
) = (
|
| 345 |
+
training_mb.batch_input_ids,
|
| 346 |
+
training_mb.batch_action_mask,
|
| 347 |
+
training_mb.batch_entropy_mask,
|
| 348 |
+
training_mb.batch_credits,
|
| 349 |
+
training_mb.batch_engine_log_probs,
|
| 350 |
+
training_mb.batch_timesteps,
|
| 351 |
+
)
|
| 352 |
+
|
| 353 |
+
# Next token prediction
|
| 354 |
+
contexts_mb = tokens_mb[:, :-1]
|
| 355 |
+
shifted_contexts_mb = tokens_mb[:, 1:]
|
| 356 |
+
action_mask_mb = action_mask_mb[:, 1:]
|
| 357 |
+
entropy_mask_mb = entropy_mask_mb[:, 1:]
|
| 358 |
+
credits_mb = credits_mb[:, 1:]
|
| 359 |
+
engine_log_probs_mb = engine_log_probs_mb[:, 1:]
|
| 360 |
+
timesteps_mb = timesteps_mb[:, 1:]
|
| 361 |
+
|
| 362 |
+
if self.enable_tokenwise_logging:
|
| 363 |
+
self.tokenwise_tally.set_action_mask(action_mask=action_mask_mb)
|
| 364 |
+
self.tokenwise_tally.set_range(range=(mb, mb + mb_size))
|
| 365 |
+
self.tokenwise_tally.add_contexts(contexts=contexts_mb)
|
| 366 |
+
self.tokenwise_tally.add_data(
|
| 367 |
+
metric_id="next_token",
|
| 368 |
+
metrics=shifted_contexts_mb,
|
| 369 |
+
to_tids=True,
|
| 370 |
+
)
|
| 371 |
+
self.tokenwise_tally.add_data(
|
| 372 |
+
metric_id="entropy_mask",
|
| 373 |
+
metrics=entropy_mask_mb,
|
| 374 |
+
)
|
| 375 |
+
|
| 376 |
+
if self.enable_tokenwise_logging:
|
| 377 |
+
self.tokenwise_tally.add_data(
|
| 378 |
+
metric_id="next_token_credit", metrics=credits_mb
|
| 379 |
+
)
|
| 380 |
+
|
| 381 |
+
# Forward pass + cast to FP-32 for higher prec.
|
| 382 |
+
# TODO: create attention mask if not relying on default (assume causal llm)
|
| 383 |
+
logits = self.policy(input_ids=contexts_mb)[0] # (B, S, V)
|
| 384 |
+
|
| 385 |
+
# Mask non-restricted tokens
|
| 386 |
+
if self.restrict_tokens is not None:
|
| 387 |
+
logits = self.mask_non_restricted_token_logits(logits)
|
| 388 |
+
|
| 389 |
+
logits /= self.temperature # (B, S, V)
|
| 390 |
+
|
| 391 |
+
# Compute new log probabilities
|
| 392 |
+
log_probs = F.log_softmax(logits, dim=-1) # (B, S, V)
|
| 393 |
+
|
| 394 |
+
# Get log probabilities of actions taken during rollouts
|
| 395 |
+
action_log_probs = log_probs.gather(
|
| 396 |
+
dim=-1, index=shifted_contexts_mb.unsqueeze(-1)
|
| 397 |
+
).squeeze(
|
| 398 |
+
-1
|
| 399 |
+
) # (B, S)
|
| 400 |
+
if self.pg_loss_normalization == "batch":
|
| 401 |
+
den_running_mean = action_mask_mb.sum() * normalization_factor
|
| 402 |
+
else:
|
| 403 |
+
den_running_mean = normalization_factor
|
| 404 |
+
running_mean_logs["log_probs"] += (
|
| 405 |
+
action_log_probs * action_mask_mb
|
| 406 |
+
).sum().item() / den_running_mean
|
| 407 |
+
running_mean_logs["credits"] += (
|
| 408 |
+
credits_mb * action_mask_mb
|
| 409 |
+
).sum().item() / den_running_mean
|
| 410 |
+
|
| 411 |
+
if self.enable_tokenwise_logging:
|
| 412 |
+
self.tokenwise_tally.add_data(
|
| 413 |
+
metric_id="next_token_log_prob",
|
| 414 |
+
metrics=action_log_probs,
|
| 415 |
+
)
|
| 416 |
+
self.tokenwise_tally.add_data(
|
| 417 |
+
metric_id="engine_next_token_log_prob",
|
| 418 |
+
metrics=engine_log_probs_mb,
|
| 419 |
+
)
|
| 420 |
+
self.tokenwise_tally.add_data(
|
| 421 |
+
metric_id="next_token_prob",
|
| 422 |
+
metrics=torch.exp(action_log_probs),
|
| 423 |
+
)
|
| 424 |
+
top_k_indices = torch.topk(logits, k=5, dim=-1).indices
|
| 425 |
+
self.tokenwise_tally.add_data(
|
| 426 |
+
metric_id=f"top_{5}_tids",
|
| 427 |
+
metrics=top_k_indices,
|
| 428 |
+
to_tids=True,
|
| 429 |
+
)
|
| 430 |
+
self.tokenwise_tally.add_data(
|
| 431 |
+
metric_id=f"top_{5}_probs",
|
| 432 |
+
metrics=torch.exp(log_probs).gather(
|
| 433 |
+
dim=-1, index=top_k_indices
|
| 434 |
+
),
|
| 435 |
+
)
|
| 436 |
+
|
| 437 |
+
rewarded_action_log_probs = (
|
| 438 |
+
action_mask_mb * credits_mb * action_log_probs
|
| 439 |
+
)
|
| 440 |
+
# (B, S)
|
| 441 |
+
INVALID_LOGPROB = 1.0
|
| 442 |
+
CLAMP_VALUE = 40.0
|
| 443 |
+
masked_action_log_probs = torch.masked_fill(
|
| 444 |
+
action_log_probs, ~action_mask_mb, INVALID_LOGPROB
|
| 445 |
+
)
|
| 446 |
+
masked_engine_log_probs = torch.masked_fill(
|
| 447 |
+
engine_log_probs_mb, ~action_mask_mb, INVALID_LOGPROB
|
| 448 |
+
)
|
| 449 |
+
with torch.no_grad():
|
| 450 |
+
action_engine_log_probs_diff = (
|
| 451 |
+
masked_action_log_probs - masked_engine_log_probs
|
| 452 |
+
).clamp(-CLAMP_VALUE, CLAMP_VALUE)
|
| 453 |
+
running_mean_logs["engine_log_probs_diff_clampfrac"] += (
|
| 454 |
+
action_engine_log_probs_diff.abs()
|
| 455 |
+
.eq(CLAMP_VALUE)
|
| 456 |
+
.float()
|
| 457 |
+
.sum()
|
| 458 |
+
.item()
|
| 459 |
+
/ den_running_mean
|
| 460 |
+
)
|
| 461 |
+
if self.importance_sampling_strategy == "per_sequence":
|
| 462 |
+
tis_imp_ratio = torch.zeros_like(action_engine_log_probs_diff)
|
| 463 |
+
for mb_idx in range(action_engine_log_probs_diff.shape[0]):
|
| 464 |
+
valid_token_mask = action_mask_mb[mb_idx]
|
| 465 |
+
timestep_ids = timesteps_mb[mb_idx][valid_token_mask]
|
| 466 |
+
timestep_logprob_diffs = action_engine_log_probs_diff[mb_idx][
|
| 467 |
+
valid_token_mask
|
| 468 |
+
]
|
| 469 |
+
max_timestep = int(timestep_ids.max().item()) + 1
|
| 470 |
+
timestep_sums = torch.zeros(
|
| 471 |
+
max_timestep,
|
| 472 |
+
device=action_engine_log_probs_diff.device,
|
| 473 |
+
dtype=action_engine_log_probs_diff.dtype,
|
| 474 |
+
)
|
| 475 |
+
timestep_sums.scatter_add_(
|
| 476 |
+
0, timestep_ids, timestep_logprob_diffs
|
| 477 |
+
)
|
| 478 |
+
timestep_ratios = torch.exp(timestep_sums)
|
| 479 |
+
tis_imp_ratio[
|
| 480 |
+
mb_idx, valid_token_mask
|
| 481 |
+
] = timestep_ratios.gather(0, timestep_ids)
|
| 482 |
+
else:
|
| 483 |
+
tis_imp_ratio = torch.exp(action_engine_log_probs_diff)
|
| 484 |
+
running_mean_logs["tis_imp_ratio"] += (
|
| 485 |
+
tis_imp_ratio * action_mask_mb
|
| 486 |
+
).sum().item() / den_running_mean
|
| 487 |
+
if self.truncated_importance_sampling_ratio_cap > 0.0:
|
| 488 |
+
tis_imp_ratio = torch.clamp(
|
| 489 |
+
tis_imp_ratio, max=self.truncated_importance_sampling_ratio_cap
|
| 490 |
+
)
|
| 491 |
+
running_mean_logs["tis_imp_ratio_clampfrac"] += (
|
| 492 |
+
tis_imp_ratio.eq(self.truncated_importance_sampling_ratio_cap)
|
| 493 |
+
.float()
|
| 494 |
+
.sum()
|
| 495 |
+
.item()
|
| 496 |
+
) / den_running_mean
|
| 497 |
+
rewarded_action_log_probs = (
|
| 498 |
+
rewarded_action_log_probs * tis_imp_ratio
|
| 499 |
+
)
|
| 500 |
+
|
| 501 |
+
if self.enable_tokenwise_logging:
|
| 502 |
+
self.tokenwise_tally.add_data(
|
| 503 |
+
metric_id="next_token_clogπ",
|
| 504 |
+
metrics=rewarded_action_log_probs,
|
| 505 |
+
)
|
| 506 |
+
|
| 507 |
+
# Add value term to loss
|
| 508 |
+
if self.pg_loss_normalization == "batch":
|
| 509 |
+
nb_act_tokens = action_mask_mb.sum()
|
| 510 |
+
mb_value = -rewarded_action_log_probs.sum() / nb_act_tokens
|
| 511 |
+
else:
|
| 512 |
+
mb_value = -rewarded_action_log_probs.sum()
|
| 513 |
+
|
| 514 |
+
loss += mb_value
|
| 515 |
+
running_mean_logs["rl_objective"] += mb_value.item() / den_running_mean
|
| 516 |
+
|
| 517 |
+
# -------------------------------------------------
|
| 518 |
+
# Entropy Regularization
|
| 519 |
+
# -------------------------------------------------
|
| 520 |
+
# Only apply entropy on distribution defined over most probable tokens
|
| 521 |
+
if self.entropy_topk is not None:
|
| 522 |
+
top_k_indices = torch.topk(
|
| 523 |
+
logits, k=self.entropy_topk, dim=-1
|
| 524 |
+
).indices
|
| 525 |
+
entropy_logits = logits.gather(dim=-1, index=top_k_indices)
|
| 526 |
+
else:
|
| 527 |
+
entropy_logits = logits
|
| 528 |
+
|
| 529 |
+
token_entropy_terms = -F.softmax(
|
| 530 |
+
entropy_logits, dim=-1
|
| 531 |
+
) * F.log_softmax(
|
| 532 |
+
entropy_logits, dim=-1
|
| 533 |
+
) # (B, S, T)
|
| 534 |
+
token_entropy_terms *= (
|
| 535 |
+
action_mask_mb[:, :, None] * entropy_mask_mb[:, :, None]
|
| 536 |
+
) # only get loss on specific action tokens
|
| 537 |
+
|
| 538 |
+
mb_entropy = token_entropy_terms.sum(dim=-1)
|
| 539 |
+
|
| 540 |
+
if self.enable_tokenwise_logging:
|
| 541 |
+
self.tokenwise_tally.add_data(
|
| 542 |
+
metric_id="entropy",
|
| 543 |
+
metrics=mb_entropy,
|
| 544 |
+
)
|
| 545 |
+
if self.pg_loss_normalization == "batch":
|
| 546 |
+
nb_act_tokens = action_mask_mb.sum()
|
| 547 |
+
mb_entropy = -mb_entropy.sum() / nb_act_tokens
|
| 548 |
+
else:
|
| 549 |
+
mb_entropy = -mb_entropy.sum()
|
| 550 |
+
running_mean_logs["entropy"] += -mb_entropy.item() / den_running_mean
|
| 551 |
+
if self.entropy_coeff != 0.0:
|
| 552 |
+
mb_entropy *= self.entropy_coeff
|
| 553 |
+
loss += mb_entropy
|
| 554 |
+
|
| 555 |
+
# -------------------------------------------------
|
| 556 |
+
# KL-DIVERGENCE
|
| 557 |
+
# -------------------------------------------------
|
| 558 |
+
if self.kl_coeff != 0.0:
|
| 559 |
+
ref_model_logits = self.policy.get_base_model_logits(contexts_mb)
|
| 560 |
+
ref_model_logits = ref_model_logits / self.temperature
|
| 561 |
+
# (B, S, V)
|
| 562 |
+
ref_model_logits = self.mask_non_restricted_token_logits(
|
| 563 |
+
logits=ref_model_logits
|
| 564 |
+
)
|
| 565 |
+
# (B, S, V)
|
| 566 |
+
ref_model_log_probs = F.log_softmax(ref_model_logits, dim=-1)
|
| 567 |
+
# (B, S, V)
|
| 568 |
+
ref_model_action_log_probs = ref_model_log_probs.gather(
|
| 569 |
+
dim=-1, index=shifted_contexts_mb.unsqueeze(-1)
|
| 570 |
+
).squeeze(
|
| 571 |
+
-1
|
| 572 |
+
) # (B,S)
|
| 573 |
+
# Approximating KL Divergence (see refs in docstring)
|
| 574 |
+
# Ref 1: http://joschu.net/blog/kl-approx.html
|
| 575 |
+
# Ref 2: https://github.dev/huggingface/trl/blob/main/trl/trainer/grpo_trainer.py#L1332
|
| 576 |
+
masked_ref_model_action_log_probs = torch.masked_fill(
|
| 577 |
+
ref_model_action_log_probs, ~action_mask_mb, INVALID_LOGPROB
|
| 578 |
+
)
|
| 579 |
+
action_log_probs_diff = (
|
| 580 |
+
masked_ref_model_action_log_probs - masked_action_log_probs
|
| 581 |
+
).clamp(-CLAMP_VALUE, CLAMP_VALUE)
|
| 582 |
+
running_mean_logs["ref_log_probs_diff_clampfrac"] += (
|
| 583 |
+
action_log_probs_diff.abs().eq(CLAMP_VALUE).float().sum().item()
|
| 584 |
+
/ den_running_mean
|
| 585 |
+
)
|
| 586 |
+
if self.filter_higher_refprob_tokens_kl:
|
| 587 |
+
higher_refprob_tokens_mask = action_log_probs_diff > 0.0
|
| 588 |
+
running_mean_logs["higher_refprob_frac"] += (
|
| 589 |
+
higher_refprob_tokens_mask.sum().item() / den_running_mean
|
| 590 |
+
)
|
| 591 |
+
action_log_probs_diff = action_log_probs_diff * (
|
| 592 |
+
~higher_refprob_tokens_mask
|
| 593 |
+
)
|
| 594 |
+
kl_div = torch.expm1(action_log_probs_diff) - action_log_probs_diff
|
| 595 |
+
kl_div *= action_mask_mb # We only care about KLD of action tokens
|
| 596 |
+
if self.truncated_importance_sampling_ratio_cap > 0.0:
|
| 597 |
+
kl_div = kl_div * tis_imp_ratio
|
| 598 |
+
kl_div *= self.kl_coeff
|
| 599 |
+
if self.enable_tokenwise_logging:
|
| 600 |
+
self.tokenwise_tally.add_data(
|
| 601 |
+
metric_id="ref_model_next_token_log_prob",
|
| 602 |
+
metrics=ref_model_action_log_probs,
|
| 603 |
+
)
|
| 604 |
+
self.tokenwise_tally.add_data(
|
| 605 |
+
metric_id="kl_divergence",
|
| 606 |
+
metrics=kl_div,
|
| 607 |
+
)
|
| 608 |
+
|
| 609 |
+
if self.pg_loss_normalization == "batch":
|
| 610 |
+
nb_act_tokens = action_mask_mb.sum()
|
| 611 |
+
mb_kl = kl_div.sum() / nb_act_tokens
|
| 612 |
+
else:
|
| 613 |
+
mb_kl = kl_div.sum()
|
| 614 |
+
running_mean_logs["kl_divergence"] += (
|
| 615 |
+
mb_kl.item() / den_running_mean
|
| 616 |
+
)
|
| 617 |
+
loss += mb_kl
|
| 618 |
+
|
| 619 |
+
# Accumulate gradient
|
| 620 |
+
running_mean_logs["policy_gradient_loss"] += (
|
| 621 |
+
loss.item() / den_running_mean
|
| 622 |
+
)
|
| 623 |
+
loss /= normalization_factor
|
| 624 |
+
self.accelerator.backward(loss)
|
| 625 |
+
|
| 626 |
+
# ensure gpu memory is freed
|
| 627 |
+
del training_mb
|
| 628 |
+
del log_probs
|
| 629 |
+
del logits
|
| 630 |
+
del loss
|
| 631 |
+
del action_log_probs
|
| 632 |
+
del rewarded_action_log_probs
|
| 633 |
+
|
| 634 |
+
logger.info(
|
| 635 |
+
f"Accumulated the policy gradient loss for {total_tokens_generated} tokens."
|
| 636 |
+
)
|
| 637 |
+
|
| 638 |
+
# Clip gradients and take step
|
| 639 |
+
if self.gradient_clipping is not None:
|
| 640 |
+
grad_norm = self.accelerator.clip_grad_norm_(
|
| 641 |
+
self.policy.parameters(), self.gradient_clipping
|
| 642 |
+
)
|
| 643 |
+
running_mean_logs["policy_gradient_norm"] += grad_norm.item()
|
| 644 |
+
|
| 645 |
+
# Take step
|
| 646 |
+
self.policy_optimizer.step()
|
| 647 |
+
self.policy_optimizer.zero_grad()
|
| 648 |
+
|
| 649 |
+
# Store logs
|
| 650 |
+
for key, value in running_mean_logs.items():
|
| 651 |
+
self.tally.add_metric(path=key, metric=value)
|
| 652 |
+
|
| 653 |
+
# Clear
|
| 654 |
+
# TODO: verify
|
| 655 |
+
self.accelerator.clear(self.policy, self.policy_optimizer)
|
| 656 |
+
import gc
|
| 657 |
+
|
| 658 |
+
gc.collect()
|
| 659 |
+
torch.cuda.empty_cache()
|
| 660 |
+
return running_mean_logs
|
| 661 |
+
|
| 662 |
+
def get_advantages_with_critic_gradient_accumulation(
|
| 663 |
+
self, trajectories: TrajectoryBatch, critic_loss_scaling_factor: float = 2.0
|
| 664 |
+
) -> torch.FloatTensor:
|
| 665 |
+
"""
|
| 666 |
+
TOWRITE
|
| 667 |
+
Uses GAE if enabled, otherwise uses Monte Carlo returns.
|
| 668 |
+
Optionally trains the critic if GAE is used.
|
| 669 |
+
Returns:
|
| 670 |
+
advantages: NestedFloatTensors
|
| 671 |
+
"""
|
| 672 |
+
|
| 673 |
+
mb_size = self.mini_batch_size
|
| 674 |
+
batch_size = trajectories.rollout_ids.shape[0]
|
| 675 |
+
agent_id = trajectories.agent_ids[0]
|
| 676 |
+
batch_rewards = trajectories.batch_rewards
|
| 677 |
+
|
| 678 |
+
######################################
|
| 679 |
+
# use critic for advantage estimation
|
| 680 |
+
######################################
|
| 681 |
+
if self.use_gae:
|
| 682 |
+
if "buffer" in agent_id:
|
| 683 |
+
self.critic.eval()
|
| 684 |
+
training = False
|
| 685 |
+
else:
|
| 686 |
+
self.critic.train()
|
| 687 |
+
training = True
|
| 688 |
+
advantages = []
|
| 689 |
+
# critic_loss_scaling_factor comes learning single critic for two agents
|
| 690 |
+
normalization_factor = (
|
| 691 |
+
np.ceil(batch_size / mb_size).astype(int) * critic_loss_scaling_factor
|
| 692 |
+
)
|
| 693 |
+
# For each minibatch
|
| 694 |
+
for mb in range(0, batch_size, mb_size):
|
| 695 |
+
trajectory_mb = trajectories[mb : mb + mb_size]
|
| 696 |
+
trajectory_mb.to(self.device)
|
| 697 |
+
rewards_mb = trajectory_mb.batch_rewards
|
| 698 |
+
(
|
| 699 |
+
tokens_mb,
|
| 700 |
+
state_ends_mask_mb,
|
| 701 |
+
timestep_counts,
|
| 702 |
+
) = trajectory_mb.get_padded_tensors_for_critic()
|
| 703 |
+
# critic causal attention up to end flags
|
| 704 |
+
if training:
|
| 705 |
+
vals_estimate_full = self.critic(tokens_mb)
|
| 706 |
+
else:
|
| 707 |
+
with torch.no_grad():
|
| 708 |
+
vals_estimate_full = self.critic(tokens_mb)
|
| 709 |
+
|
| 710 |
+
# if vals_estimate_full.dim() == 3:
|
| 711 |
+
# vals_estimate_full = vals_estimate_full.squeeze(-1)
|
| 712 |
+
|
| 713 |
+
# Select only positions where states end, per sample → list of (jT,)
|
| 714 |
+
B = tokens_mb.shape[0]
|
| 715 |
+
vals_list = [
|
| 716 |
+
vals_estimate_full[b][state_ends_mask_mb[b]] for b in range(B)
|
| 717 |
+
]
|
| 718 |
+
|
| 719 |
+
# Pad to (B, max_jT) = (B, S)
|
| 720 |
+
vals_estimate_mb = pad_sequence(
|
| 721 |
+
vals_list, batch_first=True, padding_value=0.0
|
| 722 |
+
)
|
| 723 |
+
dtype = vals_estimate_mb.dtype
|
| 724 |
+
rewards_mb = pad_sequence(
|
| 725 |
+
rewards_mb, batch_first=True, padding_value=0.0
|
| 726 |
+
).to(
|
| 727 |
+
dtype=dtype
|
| 728 |
+
) # (B, S)
|
| 729 |
+
self.rollout_tally.add_metric(
|
| 730 |
+
path=["batch_rewards"],
|
| 731 |
+
rollout_tally_item=RolloutTallyItem(
|
| 732 |
+
crn_ids=trajectory_mb.crn_ids,
|
| 733 |
+
rollout_ids=trajectory_mb.rollout_ids,
|
| 734 |
+
agent_ids=trajectory_mb.agent_ids,
|
| 735 |
+
metric_matrix=rewards_mb,
|
| 736 |
+
),
|
| 737 |
+
)
|
| 738 |
+
if self.reward_normalizing_constant != 1.0:
|
| 739 |
+
rewards_mb /= self.reward_normalizing_constant
|
| 740 |
+
|
| 741 |
+
det_vals_estimate_mb = vals_estimate_mb.detach() # (B, max_jT)
|
| 742 |
+
self.rollout_tally.add_metric(
|
| 743 |
+
path=["mb_value_estimates_critic"],
|
| 744 |
+
rollout_tally_item=RolloutTallyItem(
|
| 745 |
+
crn_ids=trajectory_mb.crn_ids,
|
| 746 |
+
rollout_ids=trajectory_mb.rollout_ids,
|
| 747 |
+
agent_ids=trajectory_mb.agent_ids,
|
| 748 |
+
metric_matrix=det_vals_estimate_mb,
|
| 749 |
+
),
|
| 750 |
+
)
|
| 751 |
+
|
| 752 |
+
# Append a 0 value to the end of the value estimates
|
| 753 |
+
if det_vals_estimate_mb.shape[1] == rewards_mb.shape[1]:
|
| 754 |
+
Bsize = det_vals_estimate_mb.shape[0]
|
| 755 |
+
device = det_vals_estimate_mb.device
|
| 756 |
+
dtype = det_vals_estimate_mb.dtype
|
| 757 |
+
det_vals_estimate_mb = torch.cat(
|
| 758 |
+
[
|
| 759 |
+
det_vals_estimate_mb,
|
| 760 |
+
torch.zeros((Bsize, 1), device=device, dtype=dtype),
|
| 761 |
+
],
|
| 762 |
+
dim=1,
|
| 763 |
+
) # (B, max_jT+1)
|
| 764 |
+
else:
|
| 765 |
+
raise ValueError(
|
| 766 |
+
"Incompatible shapes for value estimates and rewards."
|
| 767 |
+
)
|
| 768 |
+
|
| 769 |
+
# Get annealed lambda
|
| 770 |
+
if self.use_gae_lambda_annealing:
|
| 771 |
+
annealing_constant = self.gae_lambda_annealing_method(
|
| 772 |
+
step=self.trainer_annealing_state.annealing_step_counter
|
| 773 |
+
)
|
| 774 |
+
annealed_lambda = (
|
| 775 |
+
self.gae_lambda_annealing_limit * annealing_constant
|
| 776 |
+
)
|
| 777 |
+
self.tally.add_metric(
|
| 778 |
+
path="annealed_lambda", metric=annealed_lambda
|
| 779 |
+
)
|
| 780 |
+
else:
|
| 781 |
+
annealed_lambda = self.gae_lambda_annealing_limit
|
| 782 |
+
|
| 783 |
+
# Get GAE advantages
|
| 784 |
+
gae_advantages = get_generalized_advantage_estimates(
|
| 785 |
+
rewards=rewards_mb,
|
| 786 |
+
value_estimates=det_vals_estimate_mb,
|
| 787 |
+
discount_factor=self.discount_factor,
|
| 788 |
+
lambda_coef=annealed_lambda,
|
| 789 |
+
) # (B, max_jT)
|
| 790 |
+
self.rollout_tally.add_metric(
|
| 791 |
+
path=["mb_gae_advantages"],
|
| 792 |
+
rollout_tally_item=RolloutTallyItem(
|
| 793 |
+
crn_ids=trajectory_mb.crn_ids,
|
| 794 |
+
rollout_ids=trajectory_mb.rollout_ids,
|
| 795 |
+
agent_ids=trajectory_mb.agent_ids,
|
| 796 |
+
metric_matrix=gae_advantages,
|
| 797 |
+
),
|
| 798 |
+
)
|
| 799 |
+
if training:
|
| 800 |
+
targets = (
|
| 801 |
+
gae_advantages.to(dtype=dtype) + det_vals_estimate_mb[:, :-1]
|
| 802 |
+
) # (B, max_jT) # A(s, a, b) + V(s) = Q(s, a, b)
|
| 803 |
+
self.rollout_tally.add_metric(
|
| 804 |
+
path=["mb_targets_critic"],
|
| 805 |
+
rollout_tally_item=RolloutTallyItem(
|
| 806 |
+
crn_ids=trajectory_mb.crn_ids,
|
| 807 |
+
rollout_ids=trajectory_mb.rollout_ids,
|
| 808 |
+
agent_ids=trajectory_mb.agent_ids,
|
| 809 |
+
metric_matrix=targets,
|
| 810 |
+
),
|
| 811 |
+
)
|
| 812 |
+
if self.critic_loss_type == "mse":
|
| 813 |
+
loss = F.mse_loss(
|
| 814 |
+
input=vals_estimate_mb,
|
| 815 |
+
target=targets,
|
| 816 |
+
)
|
| 817 |
+
elif self.critic_loss_type == "huber":
|
| 818 |
+
loss = F.huber_loss(
|
| 819 |
+
input=vals_estimate_mb,
|
| 820 |
+
target=targets,
|
| 821 |
+
)
|
| 822 |
+
self.tally.add_metric(path=["mb_critic_loss"], metric=loss.item())
|
| 823 |
+
# Accumulate gradient
|
| 824 |
+
loss /= normalization_factor
|
| 825 |
+
self.accelerator.backward(loss)
|
| 826 |
+
del loss
|
| 827 |
+
del targets
|
| 828 |
+
del vals_estimate_mb
|
| 829 |
+
del trajectory_mb
|
| 830 |
+
del vals_estimate_full
|
| 831 |
+
|
| 832 |
+
# Get jagged back using timestep_counts
|
| 833 |
+
advantages.extend(
|
| 834 |
+
[gae_advantages[i, : timestep_counts[i]] for i in range(B)]
|
| 835 |
+
)
|
| 836 |
+
|
| 837 |
+
######################################
|
| 838 |
+
# use exclusively Monte Carlo returns & rloo for advantage estimation
|
| 839 |
+
######################################
|
| 840 |
+
else:
|
| 841 |
+
lengths = [len(c) for c in batch_rewards]
|
| 842 |
+
padded_rewards = pad_sequence(
|
| 843 |
+
batch_rewards, batch_first=True, padding_value=0.0
|
| 844 |
+
)
|
| 845 |
+
self.rollout_tally.add_metric(
|
| 846 |
+
path=["mb_rewards"],
|
| 847 |
+
rollout_tally_item=RolloutTallyItem(
|
| 848 |
+
crn_ids=trajectories.crn_ids,
|
| 849 |
+
rollout_ids=trajectories.rollout_ids,
|
| 850 |
+
agent_ids=trajectories.agent_ids,
|
| 851 |
+
metric_matrix=padded_rewards,
|
| 852 |
+
),
|
| 853 |
+
)
|
| 854 |
+
if self.reward_normalizing_constant != 1.0:
|
| 855 |
+
padded_rewards /= self.reward_normalizing_constant
|
| 856 |
+
padded_advantages = get_discounted_returns(
|
| 857 |
+
rewards=padded_rewards,
|
| 858 |
+
discount_factor=self.discount_factor,
|
| 859 |
+
) # no baseline for now
|
| 860 |
+
if self.use_rloo:
|
| 861 |
+
is_grouped_by_rng = (
|
| 862 |
+
trajectories.crn_ids.unique().shape[0]
|
| 863 |
+
!= trajectories.crn_ids.shape[0]
|
| 864 |
+
)
|
| 865 |
+
if is_grouped_by_rng:
|
| 866 |
+
for crn_id in trajectories.crn_ids.unique():
|
| 867 |
+
rng_mask = trajectories.crn_ids == crn_id
|
| 868 |
+
rng_advantages = padded_advantages[rng_mask]
|
| 869 |
+
rng_advantages, _ = get_rloo_credits(credits=rng_advantages)
|
| 870 |
+
padded_advantages[rng_mask] = rng_advantages
|
| 871 |
+
else:
|
| 872 |
+
padded_advantages, _ = get_rloo_credits(credits=padded_advantages)
|
| 873 |
+
self.rollout_tally.add_metric(
|
| 874 |
+
path=["mb_rloo_advantages"],
|
| 875 |
+
rollout_tally_item=RolloutTallyItem(
|
| 876 |
+
crn_ids=trajectories.crn_ids,
|
| 877 |
+
rollout_ids=trajectories.rollout_ids,
|
| 878 |
+
agent_ids=trajectories.agent_ids,
|
| 879 |
+
metric_matrix=padded_advantages,
|
| 880 |
+
),
|
| 881 |
+
)
|
| 882 |
+
advantages = [
|
| 883 |
+
padded_advantages[i, : lengths[i]]
|
| 884 |
+
for i in range(padded_advantages.shape[0])
|
| 885 |
+
]
|
| 886 |
+
|
| 887 |
+
if self.whiten_advantages_time_step_wise or self.whiten_advantages:
|
| 888 |
+
lengths = [len(c) for c in advantages]
|
| 889 |
+
padded_advantages = pad_sequence(
|
| 890 |
+
advantages, batch_first=True, padding_value=0.0
|
| 891 |
+
)
|
| 892 |
+
if self.whiten_advantages_time_step_wise:
|
| 893 |
+
whitened_padded_advantages = whiten_advantages_time_step_wise(
|
| 894 |
+
padded_advantages
|
| 895 |
+
)
|
| 896 |
+
path = ["mb_whitened_advantages_time_step_wise"]
|
| 897 |
+
elif self.whiten_advantages:
|
| 898 |
+
whitened_padded_advantages = whiten_advantages(padded_advantages)
|
| 899 |
+
path = ["mb_whitened_advantages"]
|
| 900 |
+
self.rollout_tally.add_metric(
|
| 901 |
+
path=path,
|
| 902 |
+
rollout_tally_item=RolloutTallyItem(
|
| 903 |
+
crn_ids=trajectories.crn_ids,
|
| 904 |
+
rollout_ids=trajectories.rollout_ids,
|
| 905 |
+
agent_ids=trajectories.agent_ids,
|
| 906 |
+
metric_matrix=whitened_padded_advantages,
|
| 907 |
+
),
|
| 908 |
+
)
|
| 909 |
+
advantages = [
|
| 910 |
+
whitened_padded_advantages[i, : lengths[i]]
|
| 911 |
+
for i in range(whitened_padded_advantages.shape[0])
|
| 912 |
+
]
|
| 913 |
+
|
| 914 |
+
self.trainer_annealing_state.annealing_step_counter += 1
|
| 915 |
+
|
| 916 |
+
return advantages
|
| 917 |
+
|
| 918 |
+
@abstractmethod
|
| 919 |
+
def set_agent_trajectory_data(
|
| 920 |
+
self, agent_id: str, roots: list[RolloutTreeRootNode]
|
| 921 |
+
) -> None:
|
| 922 |
+
"""
|
| 923 |
+
TOWRITE
|
| 924 |
+
"""
|
| 925 |
+
pass
|
| 926 |
+
|
| 927 |
+
def set_trajectory_data(
|
| 928 |
+
self, roots: list[RolloutTreeRootNode], agent_ids: list[str]
|
| 929 |
+
) -> None:
|
| 930 |
+
"""
|
| 931 |
+
TOWRITE
|
| 932 |
+
"""
|
| 933 |
+
for agent_id in agent_ids:
|
| 934 |
+
self.set_agent_trajectory_data(agent_id, roots)
|
| 935 |
+
|
| 936 |
+
@abstractmethod
|
| 937 |
+
def share_advantage_data(self) -> list[AdvantagePacket]:
|
| 938 |
+
pass
|
| 939 |
+
|
| 940 |
+
@abstractmethod
|
| 941 |
+
def receive_advantage_data(self, advantage_packets: list[AdvantagePacket]) -> None:
|
| 942 |
+
pass
|
| 943 |
+
|
| 944 |
+
def set_policy_gradient_data(self, agent_ids: list[str]) -> None:
|
| 945 |
+
"""
|
| 946 |
+
Already set earlier # TODO: make it separate and clean
|
| 947 |
+
"""
|
| 948 |
+
self.policy_gradient_data = None
|
| 949 |
+
# for agent_id, trajectory_batch in self.training_data.items():
|
| 950 |
+
# if "buffer" in agent_id:
|
| 951 |
+
# continue
|
| 952 |
+
for agent_id in agent_ids:
|
| 953 |
+
assert "buffer" not in agent_id, "Buffer agents do not train policy"
|
| 954 |
+
trajectory_batch = self.training_data[agent_id]
|
| 955 |
+
tokenwise_batch_credits = get_tokenwise_credits(
|
| 956 |
+
batch_timesteps=trajectory_batch.batch_timesteps,
|
| 957 |
+
batch_credits=trajectory_batch.batch_credits,
|
| 958 |
+
)
|
| 959 |
+
policy_gradient_data = TrainingBatch(
|
| 960 |
+
rollout_ids=trajectory_batch.rollout_ids,
|
| 961 |
+
batch_input_ids=trajectory_batch.batch_input_ids,
|
| 962 |
+
batch_action_mask=trajectory_batch.batch_action_mask,
|
| 963 |
+
batch_entropy_mask=trajectory_batch.batch_entropy_mask,
|
| 964 |
+
batch_credits=tokenwise_batch_credits,
|
| 965 |
+
batch_engine_log_probs=trajectory_batch.batch_engine_log_probs,
|
| 966 |
+
batch_timesteps=trajectory_batch.batch_timesteps,
|
| 967 |
+
)
|
| 968 |
+
if self.policy_gradient_data is None:
|
| 969 |
+
self.policy_gradient_data = policy_gradient_data
|
| 970 |
+
else:
|
| 971 |
+
self.policy_gradient_data.append(policy_gradient_data)
|
| 972 |
+
|
| 973 |
+
self.training_data = {}
|
| 974 |
+
self.tokenwise_tally = ContextualizedTokenwiseTally(
|
| 975 |
+
tokenizer=self.tokenizer,
|
| 976 |
+
paths=self.debug_path_list,
|
| 977 |
+
)
|
| 978 |
+
|
| 979 |
+
def train(self) -> None:
|
| 980 |
+
"""
|
| 981 |
+
TOWRITE
|
| 982 |
+
"""
|
| 983 |
+
assert self.policy_gradient_data is not None, "Policy gradient data is not set"
|
| 984 |
+
if self.critic_optimizer is not None:
|
| 985 |
+
if self.gradient_clipping is not None:
|
| 986 |
+
grad_norm = self.accelerator.clip_grad_norm_(
|
| 987 |
+
self.critic.parameters(), self.gradient_clipping
|
| 988 |
+
)
|
| 989 |
+
self.tally.add_metric(
|
| 990 |
+
path="gradient_norm_critic", metric=grad_norm.item()
|
| 991 |
+
)
|
| 992 |
+
# Take step
|
| 993 |
+
self.critic_optimizer.step()
|
| 994 |
+
self.critic_optimizer.zero_grad()
|
| 995 |
+
self.accelerator.clear(self.critic, self.critic_optimizer)
|
| 996 |
+
import gc
|
| 997 |
+
|
| 998 |
+
gc.collect()
|
| 999 |
+
torch.cuda.empty_cache()
|
| 1000 |
+
running_mean_logs = self.apply_reinforce_step(
|
| 1001 |
+
training_batch=self.policy_gradient_data
|
| 1002 |
+
)
|
| 1003 |
+
return running_mean_logs
|
| 1004 |
+
|
| 1005 |
+
def export_training_tally(self, identifier: str, folder: str) -> None:
|
| 1006 |
+
"""
|
| 1007 |
+
Saves and resets the collected training metrics using the tally object.
|
| 1008 |
+
"""
|
| 1009 |
+
os.makedirs(folder, exist_ok=True)
|
| 1010 |
+
self.tally.save(identifier=identifier, folder=folder)
|
| 1011 |
+
self.tokenwise_tally.save(
|
| 1012 |
+
path=os.path.join(folder, f"{identifier}_tokenwise.csv")
|
| 1013 |
+
)
|
| 1014 |
+
self.rollout_tally.save(identifier=identifier, folder=folder)
|
| 1015 |
+
self.tally.reset()
|
| 1016 |
+
self.tokenwise_tally = None
|
| 1017 |
+
self.rollout_tally.reset()
|
| 1018 |
+
self.debug_path_list = []
|
| 1019 |
+
|
| 1020 |
+
def export_optimizer_states(self) -> None:
|
| 1021 |
+
"""
|
| 1022 |
+
Saves the optimizer states for both the main model and critic (if it exists).
|
| 1023 |
+
"""
|
| 1024 |
+
try:
|
| 1025 |
+
os.makedirs(self.save_path, exist_ok=True)
|
| 1026 |
+
|
| 1027 |
+
torch.save(self.policy_optimizer.state_dict(), self.policy_optimizer_path)
|
| 1028 |
+
logger.info(f"Saved main optimizer state to {self.policy_optimizer_path}")
|
| 1029 |
+
|
| 1030 |
+
if self.critic_optimizer is not None:
|
| 1031 |
+
torch.save(
|
| 1032 |
+
self.critic_optimizer.state_dict(), self.critic_optimizer_path
|
| 1033 |
+
)
|
| 1034 |
+
logger.info(
|
| 1035 |
+
f"Saved critic optimizer state to {self.critic_optimizer_path}"
|
| 1036 |
+
)
|
| 1037 |
+
except Exception as e:
|
| 1038 |
+
logger.error(f"Error saving optimizer states: {str(e)}")
|
| 1039 |
+
raise
|
| 1040 |
+
|
| 1041 |
+
def export_trainer_annealing_state(self) -> None:
|
| 1042 |
+
"""
|
| 1043 |
+
Saves the trainer state.
|
| 1044 |
+
"""
|
| 1045 |
+
with open(self.trainer_annealing_state_path, "wb") as f:
|
| 1046 |
+
pickle.dump(self.trainer_annealing_state, f)
|
| 1047 |
+
logger.info(f"Saved trainer state to {self.trainer_annealing_state_path}")
|
| 1048 |
+
|
| 1049 |
+
def export_trainer_states(self) -> None:
|
| 1050 |
+
"""
|
| 1051 |
+
Saves the trainer states.
|
| 1052 |
+
"""
|
| 1053 |
+
self.export_optimizer_states()
|
| 1054 |
+
self.export_trainer_annealing_state()
|
src_code_for_reproducibility/training/trainer_independent.py
ADDED
|
@@ -0,0 +1,155 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
|
| 3 |
+
"""
|
| 4 |
+
import logging
|
| 5 |
+
import os
|
| 6 |
+
import sys
|
| 7 |
+
from typing import Union
|
| 8 |
+
|
| 9 |
+
import torch
|
| 10 |
+
import torch.nn.functional as F
|
| 11 |
+
from accelerate import Accelerator
|
| 12 |
+
from pandas._libs.tslibs.offsets import CBMonthBegin
|
| 13 |
+
from peft import LoraConfig
|
| 14 |
+
from torch.nn.utils.rnn import pad_sequence
|
| 15 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer
|
| 16 |
+
|
| 17 |
+
from mllm.markov_games.rollout_tree import *
|
| 18 |
+
from mllm.markov_games.rollout_tree import RolloutTreeRootNode
|
| 19 |
+
from mllm.training.credit_methods import (
|
| 20 |
+
get_discounted_returns,
|
| 21 |
+
get_discounted_state_visitation_credits,
|
| 22 |
+
get_generalized_advantage_estimates,
|
| 23 |
+
get_rloo_credits,
|
| 24 |
+
)
|
| 25 |
+
from mllm.training.tally_metrics import Tally
|
| 26 |
+
from mllm.training.tally_tokenwise import ContextualizedTokenwiseTally
|
| 27 |
+
from mllm.training.tokenize_chats import *
|
| 28 |
+
from mllm.training.tokenize_chats import process_training_chat
|
| 29 |
+
from mllm.training.trainer_common import BaseTrainer
|
| 30 |
+
from mllm.training.training_data_utils import *
|
| 31 |
+
from mllm.training.training_data_utils import (
|
| 32 |
+
TrainingBatch,
|
| 33 |
+
TrajectoryBatch,
|
| 34 |
+
get_tokenwise_credits,
|
| 35 |
+
)
|
| 36 |
+
from mllm.utils.resource_context import resource_logger_context
|
| 37 |
+
|
| 38 |
+
logger = logging.getLogger(__name__)
|
| 39 |
+
logger.addHandler(logging.StreamHandler(sys.stdout))
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
@dataclass
|
| 43 |
+
class TrainingData:
|
| 44 |
+
agent_id: str
|
| 45 |
+
main_data: TrajectoryBatch
|
| 46 |
+
# list-of-tensors: per rollout advantages with length jT
|
| 47 |
+
main_advantages: list[torch.FloatTensor] | None = None
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
class TrainerNaive(BaseTrainer):
|
| 51 |
+
def set_agent_trajectory_data(
|
| 52 |
+
self, agent_id: str, roots: list[RolloutTreeRootNode]
|
| 53 |
+
) -> None:
|
| 54 |
+
"""
|
| 55 |
+
TOWRITE
|
| 56 |
+
"""
|
| 57 |
+
# TODO: append to current batch data instead, else we will only train for one agent!
|
| 58 |
+
self.policy_gradient_data = None
|
| 59 |
+
|
| 60 |
+
# Tensorize Chats
|
| 61 |
+
rollout_ids = []
|
| 62 |
+
crn_ids = [] # common random number id
|
| 63 |
+
batch_input_ids = []
|
| 64 |
+
batch_action_mask = []
|
| 65 |
+
batch_entropy_mask = []
|
| 66 |
+
batch_timesteps = []
|
| 67 |
+
batch_state_ends_mask = []
|
| 68 |
+
batch_engine_log_probs = []
|
| 69 |
+
batch_rewards = []
|
| 70 |
+
for root in roots:
|
| 71 |
+
rollout_id = root.id
|
| 72 |
+
self.debug_path_list.append(
|
| 73 |
+
"mgid:" + str(rollout_id) + "_agent_id:" + agent_id
|
| 74 |
+
)
|
| 75 |
+
rollout_ids.append(rollout_id)
|
| 76 |
+
crn_ids.append(root.crn_id)
|
| 77 |
+
chat, rewards = get_main_chat_list_and_rewards(agent_id=agent_id, root=root)
|
| 78 |
+
(
|
| 79 |
+
input_ids,
|
| 80 |
+
action_mask,
|
| 81 |
+
entropy_mask,
|
| 82 |
+
timesteps,
|
| 83 |
+
state_ends_mask,
|
| 84 |
+
engine_log_probs,
|
| 85 |
+
) = process_training_chat(
|
| 86 |
+
tokenizer=self.tokenizer,
|
| 87 |
+
chat_history=chat,
|
| 88 |
+
entropy_mask_regex=self.entropy_mask_regex,
|
| 89 |
+
exploration_prompts_to_remove=self.exploration_prompts_to_remove,
|
| 90 |
+
)
|
| 91 |
+
batch_input_ids.append(input_ids)
|
| 92 |
+
batch_action_mask.append(action_mask)
|
| 93 |
+
batch_entropy_mask.append(entropy_mask)
|
| 94 |
+
batch_timesteps.append(timesteps)
|
| 95 |
+
batch_state_ends_mask.append(state_ends_mask)
|
| 96 |
+
batch_engine_log_probs.append(engine_log_probs)
|
| 97 |
+
batch_rewards.append(rewards)
|
| 98 |
+
|
| 99 |
+
trajectory_batch = TrajectoryBatch(
|
| 100 |
+
rollout_ids=torch.tensor(rollout_ids, dtype=torch.int32),
|
| 101 |
+
crn_ids=torch.tensor(crn_ids, dtype=torch.int32),
|
| 102 |
+
agent_ids=[agent_id] * len(rollout_ids),
|
| 103 |
+
batch_input_ids=batch_input_ids,
|
| 104 |
+
batch_action_mask=batch_action_mask,
|
| 105 |
+
batch_entropy_mask=batch_entropy_mask,
|
| 106 |
+
batch_timesteps=batch_timesteps,
|
| 107 |
+
batch_state_ends_mask=batch_state_ends_mask,
|
| 108 |
+
batch_rewards=batch_rewards,
|
| 109 |
+
batch_engine_log_probs=batch_engine_log_probs,
|
| 110 |
+
)
|
| 111 |
+
|
| 112 |
+
# Get Advantages
|
| 113 |
+
batch_advantages: torch.FloatTensor = (
|
| 114 |
+
self.get_advantages_with_critic_gradient_accumulation(trajectory_batch)
|
| 115 |
+
)
|
| 116 |
+
|
| 117 |
+
# Discount state visitation (the mathematically correct way)
|
| 118 |
+
if not self.skip_discounted_state_visitation:
|
| 119 |
+
for i in range(len(batch_advantages)):
|
| 120 |
+
batch_advantages[i] = get_discounted_state_visitation_credits(
|
| 121 |
+
batch_advantages[i].unsqueeze(0),
|
| 122 |
+
self.discount_factor,
|
| 123 |
+
).squeeze(0)
|
| 124 |
+
|
| 125 |
+
self.training_data[agent_id] = TrainingData(
|
| 126 |
+
agent_id=agent_id,
|
| 127 |
+
main_data=trajectory_batch,
|
| 128 |
+
main_advantages=batch_advantages,
|
| 129 |
+
)
|
| 130 |
+
|
| 131 |
+
def receive_advantage_data(self, advantage_packets: list[AdvantagePacket]):
|
| 132 |
+
"""
|
| 133 |
+
This trainer ignores the advantages of the other trainers.
|
| 134 |
+
"""
|
| 135 |
+
for agent_id, agent_data in self.training_data.items():
|
| 136 |
+
self.training_data[agent_id] = agent_data.main_data
|
| 137 |
+
self.training_data[agent_id].batch_credits = agent_data.main_advantages
|
| 138 |
+
|
| 139 |
+
def share_advantage_data(self) -> list[AdvantagePacket]:
|
| 140 |
+
"""
|
| 141 |
+
Share the advantage data with other agents.
|
| 142 |
+
Returns:
|
| 143 |
+
AdvantagePacket: The advantage packet containing the agent's advantages.
|
| 144 |
+
"""
|
| 145 |
+
logger.info(f"Sharing advantage data.")
|
| 146 |
+
advantage_packets = []
|
| 147 |
+
for agent_id, agent_data in self.training_data.items():
|
| 148 |
+
advantage_packets.append(
|
| 149 |
+
AdvantagePacket(
|
| 150 |
+
agent_id=agent_id,
|
| 151 |
+
rollout_ids=agent_data.main_data.rollout_ids,
|
| 152 |
+
main_advantages=agent_data.main_advantages,
|
| 153 |
+
)
|
| 154 |
+
)
|
| 155 |
+
return advantage_packets
|
src_code_for_reproducibility/training/training_data_utils.py
ADDED
|
@@ -0,0 +1,394 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from dataclasses import dataclass
|
| 2 |
+
from typing import Literal, Optional, Tuple
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
from torch.nn.utils.rnn import pad_sequence
|
| 6 |
+
|
| 7 |
+
from mllm.markov_games.rollout_tree import (
|
| 8 |
+
ChatTurn,
|
| 9 |
+
RolloutTreeBranchNode,
|
| 10 |
+
RolloutTreeNode,
|
| 11 |
+
RolloutTreeRootNode,
|
| 12 |
+
)
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
@dataclass
|
| 16 |
+
class AdvantagePacket:
|
| 17 |
+
agent_id: str
|
| 18 |
+
rollout_ids: torch.IntTensor # (B,)
|
| 19 |
+
# list-of-tensors
|
| 20 |
+
main_advantages: list[torch.FloatTensor]
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
class TrainingChatTurn:
|
| 24 |
+
# TODO: simplify by making this a child of ChatTurn
|
| 25 |
+
"""
|
| 26 |
+
This class contains the chat turns for a single agent.
|
| 27 |
+
It is like ChatTurn, but with the time step added.
|
| 28 |
+
"""
|
| 29 |
+
|
| 30 |
+
def __init__(
|
| 31 |
+
self,
|
| 32 |
+
time_step: int,
|
| 33 |
+
role: str,
|
| 34 |
+
agent_id: str,
|
| 35 |
+
content: str,
|
| 36 |
+
chat_template_token_ids: list[int],
|
| 37 |
+
reasoning_content: str,
|
| 38 |
+
is_state_end: bool,
|
| 39 |
+
out_token_ids: Optional[list[int]] = None,
|
| 40 |
+
log_probs: Optional[list[float]] = None,
|
| 41 |
+
) -> None:
|
| 42 |
+
self.time_step = time_step
|
| 43 |
+
self.role = role
|
| 44 |
+
self.agent_id = agent_id
|
| 45 |
+
self.content = content
|
| 46 |
+
self.chat_template_token_ids = chat_template_token_ids
|
| 47 |
+
self.reasoning_content = reasoning_content
|
| 48 |
+
self.is_state_end = is_state_end
|
| 49 |
+
self.out_token_ids = out_token_ids
|
| 50 |
+
self.log_probs = log_probs
|
| 51 |
+
|
| 52 |
+
def dict(self):
|
| 53 |
+
return {
|
| 54 |
+
"time_step": self.time_step,
|
| 55 |
+
"role": self.role,
|
| 56 |
+
"agent_id": self.agent_id,
|
| 57 |
+
"content": self.content,
|
| 58 |
+
"chat_template_token_ids": self.chat_template_token_ids,
|
| 59 |
+
"reasoning_content": self.reasoning_content,
|
| 60 |
+
"is_state_end": self.is_state_end,
|
| 61 |
+
"out_token_ids": self.out_token_ids,
|
| 62 |
+
"log_probs": self.log_probs,
|
| 63 |
+
}
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
def get_main_chat_list_and_rewards(
|
| 67 |
+
agent_id: str, root: RolloutTreeRootNode | RolloutTreeNode
|
| 68 |
+
) -> Tuple[list[TrainingChatTurn], torch.FloatTensor]:
|
| 69 |
+
"""
|
| 70 |
+
This method traverses a rollout tree and returns a the list of ChatTurn
|
| 71 |
+
for an agent. If it encounters a branch node, it follows the main path.
|
| 72 |
+
"""
|
| 73 |
+
# TODO; extend for all trees, not just linear
|
| 74 |
+
if isinstance(root, RolloutTreeRootNode):
|
| 75 |
+
current_node = root.child
|
| 76 |
+
else:
|
| 77 |
+
current_node = root
|
| 78 |
+
|
| 79 |
+
chat = []
|
| 80 |
+
rewards = []
|
| 81 |
+
while current_node is not None:
|
| 82 |
+
if isinstance(current_node, RolloutTreeBranchNode):
|
| 83 |
+
current_node = current_node.main_child
|
| 84 |
+
reward: float = current_node.step_log.simulation_step_log.rewards[agent_id]
|
| 85 |
+
rewards.append(reward)
|
| 86 |
+
chat_turns: list[TrainingChatTurn] = current_node.step_log.action_logs[
|
| 87 |
+
agent_id
|
| 88 |
+
].chat_turns
|
| 89 |
+
chat_turns = [
|
| 90 |
+
TrainingChatTurn(time_step=current_node.time_step, **turn.model_dump())
|
| 91 |
+
for turn in chat_turns
|
| 92 |
+
]
|
| 93 |
+
chat.extend(chat_turns)
|
| 94 |
+
current_node = current_node.child
|
| 95 |
+
return chat, torch.FloatTensor(rewards)
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
def get_tokenwise_credits(
|
| 99 |
+
# B := batch size, S := number of tokens / seq. length, T := number of states. `j` stands for jagged (see pytorch nested tensors.)
|
| 100 |
+
batch_timesteps: torch.IntTensor | torch.Tensor, # (B, jS),
|
| 101 |
+
batch_credits: torch.FloatTensor | torch.Tensor, # (B, jT)
|
| 102 |
+
) -> torch.FloatTensor | torch.Tensor: # (B, jS)
|
| 103 |
+
"""
|
| 104 |
+
TOWRITE
|
| 105 |
+
"""
|
| 106 |
+
# TODO vectorize this code
|
| 107 |
+
batch_token_credits = []
|
| 108 |
+
for credits, timesteps in zip(batch_credits, batch_timesteps):
|
| 109 |
+
token_credits = torch.zeros_like(
|
| 110 |
+
timesteps,
|
| 111 |
+
dtype=credits.dtype,
|
| 112 |
+
device=timesteps.device,
|
| 113 |
+
)
|
| 114 |
+
for idx, credit in enumerate(credits):
|
| 115 |
+
token_credits[timesteps == idx] = credit
|
| 116 |
+
batch_token_credits.append(token_credits)
|
| 117 |
+
return batch_token_credits
|
| 118 |
+
|
| 119 |
+
|
| 120 |
+
@dataclass
|
| 121 |
+
class TrajectoryBatch:
|
| 122 |
+
"""
|
| 123 |
+
Tensorized batch of trajectories using list-of-tensors for jagged dimensions.
|
| 124 |
+
"""
|
| 125 |
+
|
| 126 |
+
# B := batch size, S := number of tokens / seq. length, T := number of states.
|
| 127 |
+
rollout_ids: torch.IntTensor # (B,)
|
| 128 |
+
crn_ids: torch.IntTensor # (B,)
|
| 129 |
+
agent_ids: list[str] # (B,)
|
| 130 |
+
batch_input_ids: list[torch.LongTensor] # List[(jS,)]
|
| 131 |
+
batch_action_mask: list[torch.BoolTensor] # List[(jS,)]
|
| 132 |
+
batch_entropy_mask: list[torch.BoolTensor] # List[(jS,)]
|
| 133 |
+
batch_timesteps: list[torch.IntTensor] # List[(jS,)]
|
| 134 |
+
batch_state_ends_mask: list[torch.BoolTensor] # List[(jS,)]
|
| 135 |
+
batch_engine_log_probs: Optional[list[torch.FloatTensor]] # List[(jS,)]
|
| 136 |
+
batch_rewards: list[torch.FloatTensor] # List[(jT,)]
|
| 137 |
+
batch_credits: Optional[list[torch.FloatTensor]] = None # List[(jS,)]
|
| 138 |
+
|
| 139 |
+
def __post_init__(self):
|
| 140 |
+
"""
|
| 141 |
+
Validate per-sample consistency.
|
| 142 |
+
"""
|
| 143 |
+
B = self.rollout_ids.shape[0]
|
| 144 |
+
assert (
|
| 145 |
+
self.crn_ids.shape[0] == B
|
| 146 |
+
), "RNG IDs must have length equal to batch size."
|
| 147 |
+
assert (
|
| 148 |
+
len(self.agent_ids) == B
|
| 149 |
+
), "agent_ids must have length equal to batch size."
|
| 150 |
+
assert (
|
| 151 |
+
len(self.batch_input_ids)
|
| 152 |
+
== len(self.batch_action_mask)
|
| 153 |
+
== len(self.batch_entropy_mask)
|
| 154 |
+
== len(self.batch_timesteps)
|
| 155 |
+
== len(self.batch_state_ends_mask)
|
| 156 |
+
== len(self.batch_engine_log_probs)
|
| 157 |
+
== len(self.batch_rewards)
|
| 158 |
+
== B
|
| 159 |
+
), "Jagged lists must all have length equal to batch size."
|
| 160 |
+
|
| 161 |
+
for b in range(B):
|
| 162 |
+
nb_rewards = int(self.batch_rewards[b].shape[0])
|
| 163 |
+
nb_timesteps = int(torch.max(self.batch_timesteps[b]).item()) + 1
|
| 164 |
+
assert (
|
| 165 |
+
nb_rewards == nb_timesteps
|
| 166 |
+
), "Number of rewards and timesteps mismatch."
|
| 167 |
+
assert (
|
| 168 |
+
self.batch_input_ids[b].shape[0]
|
| 169 |
+
== self.batch_action_mask[b].shape[0]
|
| 170 |
+
== self.batch_entropy_mask[b].shape[0]
|
| 171 |
+
== self.batch_engine_log_probs[b].shape[0]
|
| 172 |
+
== self.batch_timesteps[b].shape[0]
|
| 173 |
+
), "Tensors must have the same shape along the jagged dimension."
|
| 174 |
+
assert (
|
| 175 |
+
int(self.batch_state_ends_mask[b].sum())
|
| 176 |
+
== self.batch_rewards[b].shape[0]
|
| 177 |
+
), "Number of rewards must match number of state ends."
|
| 178 |
+
|
| 179 |
+
"""
|
| 180 |
+
Entries:
|
| 181 |
+
Here, we ignore the batch dimension.
|
| 182 |
+
input_ids:
|
| 183 |
+
All of the tokens of both the user and the assistant, flattened.
|
| 184 |
+
action_mask:
|
| 185 |
+
Set to true on the tokens of the assistant (tokens generated by the model).
|
| 186 |
+
timesteps:
|
| 187 |
+
Therefore, max(timesteps) = Ns - 1.
|
| 188 |
+
state_ends_idx:
|
| 189 |
+
Indices of the tokens at which state descriptions end.
|
| 190 |
+
rewards:
|
| 191 |
+
rewards[t] := R_t(s_t, a_t)
|
| 192 |
+
Example:
|
| 193 |
+
position: "0 1 2 3 4 5 6 7 8 9 10 11 12 13 14"
|
| 194 |
+
input_ids: "U U U a a a U a U a a a U U U" (U := User, a := Assistant)
|
| 195 |
+
action_mask: "x x x ✓ ✓ ✓ x ✓ x ✓ ✓ ✓ x x x"
|
| 196 |
+
timestep: "0 0 0 0 0 0 1 1 1 1 1 1 2 2 2"
|
| 197 |
+
state_ends_dx: [2, 6, 14]
|
| 198 |
+
rewards: [r0, r1, r2]
|
| 199 |
+
"""
|
| 200 |
+
|
| 201 |
+
def __getitem__(self, key) -> "TrajectoryBatch":
|
| 202 |
+
if isinstance(key, slice):
|
| 203 |
+
return TrajectoryBatch(
|
| 204 |
+
rollout_ids=self.rollout_ids.__getitem__(key),
|
| 205 |
+
crn_ids=self.crn_ids.__getitem__(key),
|
| 206 |
+
agent_ids=self.agent_ids[key],
|
| 207 |
+
batch_input_ids=self.batch_input_ids[key],
|
| 208 |
+
batch_action_mask=self.batch_action_mask[key],
|
| 209 |
+
batch_entropy_mask=self.batch_entropy_mask[key],
|
| 210 |
+
batch_timesteps=self.batch_timesteps[key],
|
| 211 |
+
batch_state_ends_mask=self.batch_state_ends_mask[key],
|
| 212 |
+
batch_engine_log_probs=self.batch_engine_log_probs[key],
|
| 213 |
+
batch_rewards=self.batch_rewards[key],
|
| 214 |
+
batch_credits=self.batch_credits[key] if self.batch_credits else None,
|
| 215 |
+
)
|
| 216 |
+
|
| 217 |
+
def __len__(self):
|
| 218 |
+
return len(self.batch_input_ids)
|
| 219 |
+
|
| 220 |
+
def to(self, device):
|
| 221 |
+
self.rollout_ids = self.rollout_ids.to(device)
|
| 222 |
+
self.crn_ids = self.crn_ids.to(device)
|
| 223 |
+
self.batch_input_ids = [t.to(device) for t in self.batch_input_ids]
|
| 224 |
+
self.batch_action_mask = [t.to(device) for t in self.batch_action_mask]
|
| 225 |
+
self.batch_entropy_mask = [t.to(device) for t in self.batch_entropy_mask]
|
| 226 |
+
self.batch_timesteps = [t.to(device) for t in self.batch_timesteps]
|
| 227 |
+
self.batch_state_ends_mask = [t.to(device) for t in self.batch_state_ends_mask]
|
| 228 |
+
self.batch_engine_log_probs = [
|
| 229 |
+
t.to(device) for t in self.batch_engine_log_probs
|
| 230 |
+
]
|
| 231 |
+
self.batch_rewards = [t.to(device) for t in self.batch_rewards]
|
| 232 |
+
self.batch_credits = (
|
| 233 |
+
[t.to(device) for t in self.batch_credits] if self.batch_credits else None
|
| 234 |
+
)
|
| 235 |
+
|
| 236 |
+
def get_padded_tensors_for_critic(self):
|
| 237 |
+
"""
|
| 238 |
+
Returns:
|
| 239 |
+
padded_batch_input_ids: (B, P)
|
| 240 |
+
padded_batch_state_ends_mask: (B, P)
|
| 241 |
+
timestep_counts: (B,) tensor of ints indicating number of states per sample
|
| 242 |
+
"""
|
| 243 |
+
padded_batch_input_ids = pad_sequence(
|
| 244 |
+
self.batch_input_ids, batch_first=True, padding_value=0
|
| 245 |
+
)
|
| 246 |
+
padded_batch_state_ends_mask = pad_sequence(
|
| 247 |
+
self.batch_state_ends_mask, batch_first=True, padding_value=0
|
| 248 |
+
).bool()
|
| 249 |
+
# number of states equals number of True in state_ends_mask
|
| 250 |
+
timestep_counts = torch.tensor(
|
| 251 |
+
[int(mask.sum().item()) for mask in self.batch_state_ends_mask],
|
| 252 |
+
device=padded_batch_input_ids.device,
|
| 253 |
+
dtype=torch.long,
|
| 254 |
+
)
|
| 255 |
+
return padded_batch_input_ids, padded_batch_state_ends_mask, timestep_counts
|
| 256 |
+
|
| 257 |
+
|
| 258 |
+
timestep = int
|
| 259 |
+
|
| 260 |
+
|
| 261 |
+
@dataclass
|
| 262 |
+
class PaddedTensorTrainingBatch:
|
| 263 |
+
batch_input_ids: torch.LongTensor | torch.Tensor
|
| 264 |
+
batch_action_mask: torch.BoolTensor | torch.Tensor
|
| 265 |
+
batch_entropy_mask: Optional[torch.BoolTensor | torch.Tensor]
|
| 266 |
+
batch_credits: torch.FloatTensor | torch.Tensor
|
| 267 |
+
batch_engine_log_probs: torch.FloatTensor | torch.Tensor
|
| 268 |
+
batch_timesteps: torch.IntTensor | torch.Tensor
|
| 269 |
+
|
| 270 |
+
def __len__(self):
|
| 271 |
+
return self.batch_input_ids.shape[0]
|
| 272 |
+
|
| 273 |
+
def to(self, device):
|
| 274 |
+
self.batch_input_ids = self.batch_input_ids.to(device)
|
| 275 |
+
self.batch_action_mask = self.batch_action_mask.to(device)
|
| 276 |
+
self.batch_entropy_mask = self.batch_entropy_mask.to(device)
|
| 277 |
+
self.batch_credits = self.batch_credits.to(device)
|
| 278 |
+
self.batch_engine_log_probs = self.batch_engine_log_probs.to(device)
|
| 279 |
+
self.batch_timesteps = self.batch_timesteps.to(device)
|
| 280 |
+
|
| 281 |
+
|
| 282 |
+
@dataclass
|
| 283 |
+
class TrainingBatch:
|
| 284 |
+
rollout_ids: torch.IntTensor | torch.Tensor # (B,)
|
| 285 |
+
batch_input_ids: list[torch.LongTensor] # List[(jS,)]
|
| 286 |
+
batch_action_mask: list[torch.BoolTensor] # List[(jS,)]
|
| 287 |
+
batch_entropy_mask: Optional[list[torch.BoolTensor]] # List[(jS,)]
|
| 288 |
+
batch_credits: list[torch.FloatTensor] # List[(jS,)]
|
| 289 |
+
batch_engine_log_probs: list[torch.FloatTensor] # List[(jS,)]
|
| 290 |
+
batch_timesteps: list[torch.IntTensor] # List[(jS,)]
|
| 291 |
+
|
| 292 |
+
def __post_init__(self):
|
| 293 |
+
# Put everything in the right device
|
| 294 |
+
# self.rollout_ids = self.rollout_ids.to("cuda" if torch.cuda.is_available() else "cpu")
|
| 295 |
+
# self.batch_input_ids = self.batch_input_ids.to("cuda" if torch.cuda.is_available() else "cpu")
|
| 296 |
+
# self.batch_action_mask = self.batch_action_mask.to("cuda" if torch.cuda.is_available() else "cpu")
|
| 297 |
+
# self.batch_credits = self.batch_credits.to("cuda" if torch.cuda.is_available() else "cpu")
|
| 298 |
+
# Ensure batch dimension is present
|
| 299 |
+
assert (
|
| 300 |
+
len(self.batch_input_ids)
|
| 301 |
+
== len(self.batch_action_mask)
|
| 302 |
+
== len(self.batch_entropy_mask)
|
| 303 |
+
== len(self.batch_credits)
|
| 304 |
+
== len(self.batch_engine_log_probs)
|
| 305 |
+
== len(self.batch_timesteps)
|
| 306 |
+
== self.rollout_ids.shape[0]
|
| 307 |
+
), "Jagged lists must all have length equal to batch size."
|
| 308 |
+
for inp, mask, cred, engine_log_prob, timestep in zip(
|
| 309 |
+
self.batch_input_ids,
|
| 310 |
+
self.batch_action_mask,
|
| 311 |
+
self.batch_credits,
|
| 312 |
+
self.batch_engine_log_probs,
|
| 313 |
+
self.batch_timesteps,
|
| 314 |
+
):
|
| 315 |
+
assert (
|
| 316 |
+
inp.shape[0]
|
| 317 |
+
== mask.shape[0]
|
| 318 |
+
== cred.shape[0]
|
| 319 |
+
== engine_log_prob.shape[0]
|
| 320 |
+
== timestep.shape[0]
|
| 321 |
+
), "Tensors must have the same shapes along the jagged dimension."
|
| 322 |
+
|
| 323 |
+
def __getitem__(self, key) -> "TrainingBatch":
|
| 324 |
+
if isinstance(key, slice):
|
| 325 |
+
return TrainingBatch(
|
| 326 |
+
rollout_ids=self.rollout_ids.__getitem__(key),
|
| 327 |
+
batch_input_ids=self.batch_input_ids[key],
|
| 328 |
+
batch_action_mask=self.batch_action_mask[key],
|
| 329 |
+
batch_entropy_mask=self.batch_entropy_mask[key],
|
| 330 |
+
batch_credits=self.batch_credits[key],
|
| 331 |
+
batch_engine_log_probs=self.batch_engine_log_probs[key],
|
| 332 |
+
batch_timesteps=self.batch_timesteps[key],
|
| 333 |
+
)
|
| 334 |
+
|
| 335 |
+
def __len__(self):
|
| 336 |
+
return len(self.batch_input_ids)
|
| 337 |
+
|
| 338 |
+
def to(self, device):
|
| 339 |
+
self.rollout_ids = self.rollout_ids.to(device)
|
| 340 |
+
self.batch_input_ids = [t.to(device) for t in self.batch_input_ids]
|
| 341 |
+
self.batch_action_mask = [t.to(device) for t in self.batch_action_mask]
|
| 342 |
+
self.batch_entropy_mask = [t.to(device) for t in self.batch_entropy_mask]
|
| 343 |
+
self.batch_credits = [t.to(device) for t in self.batch_credits]
|
| 344 |
+
self.batch_engine_log_probs = [
|
| 345 |
+
t.to(device) for t in self.batch_engine_log_probs
|
| 346 |
+
]
|
| 347 |
+
self.batch_timesteps = [t.to(device) for t in self.batch_timesteps]
|
| 348 |
+
|
| 349 |
+
def get_padded_tensors(self, padding: float = 0.0):
|
| 350 |
+
"""
|
| 351 |
+
TOWRITE
|
| 352 |
+
Always pad to the right.
|
| 353 |
+
"""
|
| 354 |
+
padded_batch_input_ids = pad_sequence(
|
| 355 |
+
self.batch_input_ids, batch_first=True, padding_value=int(padding)
|
| 356 |
+
)
|
| 357 |
+
padded_batch_action_mask = pad_sequence(
|
| 358 |
+
[m.to(dtype=torch.bool) for m in self.batch_action_mask],
|
| 359 |
+
batch_first=True,
|
| 360 |
+
padding_value=False,
|
| 361 |
+
)
|
| 362 |
+
padded_batch_entropy_mask = pad_sequence(
|
| 363 |
+
self.batch_entropy_mask, batch_first=True, padding_value=False
|
| 364 |
+
)
|
| 365 |
+
padded_batch_credits = pad_sequence(
|
| 366 |
+
self.batch_credits, batch_first=True, padding_value=float(padding)
|
| 367 |
+
)
|
| 368 |
+
padded_batch_engine_log_probs = pad_sequence(
|
| 369 |
+
self.batch_engine_log_probs, batch_first=True, padding_value=float(padding)
|
| 370 |
+
)
|
| 371 |
+
padded_batch_timesteps = pad_sequence(
|
| 372 |
+
self.batch_timesteps, batch_first=True, padding_value=0
|
| 373 |
+
)
|
| 374 |
+
|
| 375 |
+
return PaddedTensorTrainingBatch(
|
| 376 |
+
padded_batch_input_ids,
|
| 377 |
+
padded_batch_action_mask,
|
| 378 |
+
padded_batch_entropy_mask,
|
| 379 |
+
padded_batch_credits,
|
| 380 |
+
padded_batch_engine_log_probs,
|
| 381 |
+
padded_batch_timesteps,
|
| 382 |
+
)
|
| 383 |
+
|
| 384 |
+
def append(self, other: "TrainingBatch"):
|
| 385 |
+
self.rollout_ids = torch.cat([self.rollout_ids, other.rollout_ids])
|
| 386 |
+
self.batch_input_ids.extend(other.batch_input_ids)
|
| 387 |
+
self.batch_action_mask.extend(other.batch_action_mask)
|
| 388 |
+
self.batch_entropy_mask.extend(other.batch_entropy_mask)
|
| 389 |
+
self.batch_credits.extend(other.batch_credits)
|
| 390 |
+
self.batch_engine_log_probs.extend(other.batch_engine_log_probs)
|
| 391 |
+
self.batch_timesteps.extend(other.batch_timesteps)
|
| 392 |
+
|
| 393 |
+
|
| 394 |
+
timestep = int
|
src_code_for_reproducibility/utils/get_stochastic_game_lengths.py
ADDED
|
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
|
| 3 |
+
def get_stochastic_game_lengths(
|
| 4 |
+
max_length,
|
| 5 |
+
nb_games,
|
| 6 |
+
continuation_prob,
|
| 7 |
+
same_length_batch=False
|
| 8 |
+
):
|
| 9 |
+
"""
|
| 10 |
+
Generates stochastic game lengths based on a geometric distribution.
|
| 11 |
+
|
| 12 |
+
Args:
|
| 13 |
+
max_length (int): The maximum length a game can have.
|
| 14 |
+
nb_games (int): The number of games to generate lengths for.
|
| 15 |
+
continuation_prob (float): The probability of the game continuing after each round.
|
| 16 |
+
same_length_batch (bool): If True, all games will have the same length.
|
| 17 |
+
|
| 18 |
+
Returns:
|
| 19 |
+
Array: An array of game lengths.
|
| 20 |
+
"""
|
| 21 |
+
if continuation_prob == 1:
|
| 22 |
+
return [max_length] * nb_games
|
| 23 |
+
if same_length_batch:
|
| 24 |
+
length = np.random.geometric(1 - continuation_prob, 1)
|
| 25 |
+
game_lengths = np.repeat(length, nb_games)
|
| 26 |
+
else:
|
| 27 |
+
game_lengths = np.random.geometric(1 - continuation_prob, nb_games)
|
| 28 |
+
|
| 29 |
+
game_lengths = np.where(game_lengths > max_length, max_length, game_lengths)
|
| 30 |
+
return game_lengths.tolist()
|
src_code_for_reproducibility/utils/resource_context.py
ADDED
|
@@ -0,0 +1,78 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import logging
|
| 2 |
+
import time
|
| 3 |
+
from contextlib import contextmanager
|
| 4 |
+
|
| 5 |
+
import torch
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
def vram_usage():
|
| 9 |
+
output = ""
|
| 10 |
+
for i in range(torch.cuda.device_count()):
|
| 11 |
+
gpu_memory_allocated = torch.cuda.memory_allocated(i) / (
|
| 12 |
+
1024**3
|
| 13 |
+
) # Convert bytes to GB
|
| 14 |
+
gpu_memory_reserved = torch.cuda.memory_reserved(i) / (
|
| 15 |
+
1024**3
|
| 16 |
+
) # Convert bytes to GB
|
| 17 |
+
output += f"GPU {i}: Memory Allocated: {gpu_memory_allocated:.2f} GB, Memory Reserved: {gpu_memory_reserved:.2f} GB"
|
| 18 |
+
return output
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
def ram_usage():
|
| 22 |
+
import psutil
|
| 23 |
+
|
| 24 |
+
process = psutil.Process()
|
| 25 |
+
memory_info = process.memory_info()
|
| 26 |
+
ram_used = memory_info.rss / (1024**3) # Convert bytes to GB
|
| 27 |
+
return f"RAM Usage: {ram_used:.2f} GB"
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
@contextmanager
|
| 31 |
+
def resource_logger_context(logger: logging.Logger, task_description: str):
|
| 32 |
+
"""
|
| 33 |
+
Context manager to log the resource usage of the current task.
|
| 34 |
+
Args:
|
| 35 |
+
logger: The logger to use to log the resource usage.
|
| 36 |
+
task_description: The description of the task to log.
|
| 37 |
+
Returns:
|
| 38 |
+
None
|
| 39 |
+
"""
|
| 40 |
+
try:
|
| 41 |
+
initial_time = time.time()
|
| 42 |
+
# Assume CUDA is available and use device 0 only
|
| 43 |
+
total_mem_bytes = torch.cuda.get_device_properties(0).total_memory
|
| 44 |
+
initial_total_bytes = (
|
| 45 |
+
torch.cuda.memory_allocated(0) + torch.cuda.memory_reserved(0)
|
| 46 |
+
)
|
| 47 |
+
torch.cuda.reset_peak_memory_stats(0)
|
| 48 |
+
yield None
|
| 49 |
+
finally:
|
| 50 |
+
final_time = time.time()
|
| 51 |
+
# Ensure kernels within the block are accounted for
|
| 52 |
+
torch.cuda.synchronize()
|
| 53 |
+
|
| 54 |
+
# Compute metrics
|
| 55 |
+
final_allocated_bytes = torch.cuda.memory_allocated(0)
|
| 56 |
+
final_reserved_bytes = torch.cuda.memory_reserved(0)
|
| 57 |
+
final_total_bytes = final_allocated_bytes + final_reserved_bytes
|
| 58 |
+
|
| 59 |
+
delta_vram_percent_total = (
|
| 60 |
+
100 * (final_total_bytes - initial_total_bytes) / total_mem_bytes
|
| 61 |
+
if total_mem_bytes
|
| 62 |
+
else 0.0
|
| 63 |
+
)
|
| 64 |
+
current_percent_vram_taken = (
|
| 65 |
+
100 * final_total_bytes / total_mem_bytes if total_mem_bytes else 0.0
|
| 66 |
+
)
|
| 67 |
+
block_peak_percent = (
|
| 68 |
+
100 * torch.cuda.max_memory_allocated(0) / total_mem_bytes
|
| 69 |
+
if total_mem_bytes
|
| 70 |
+
else 0.0
|
| 71 |
+
)
|
| 72 |
+
delta_time_str = time.strftime(
|
| 73 |
+
'%H:%M:%S', time.gmtime(final_time - initial_time)
|
| 74 |
+
)
|
| 75 |
+
|
| 76 |
+
logger.info(
|
| 77 |
+
f"For task: {task_description}, ΔVRAM % (total): {delta_vram_percent_total:.2f}%, Current % of VRAM taken: {current_percent_vram_taken:.2f}%, Block Peak % of device VRAM: {block_peak_percent:.2f}%, ΔTime: {delta_time_str}"
|
| 78 |
+
)
|
src_code_for_reproducibility/utils/rollout_tree_chat_htmls.py
ADDED
|
@@ -0,0 +1,1921 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from pathlib import Path
|
| 2 |
+
from typing import List
|
| 3 |
+
|
| 4 |
+
from mllm.utils.rollout_tree_gather_utils import *
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
def html_from_chat_turns(chat_turns: List[ChatTurnLog]) -> str:
|
| 8 |
+
"""
|
| 9 |
+
Render chat turns as a single, wrapping sequence of messages in time order.
|
| 10 |
+
Keep badge and message bubble styles, include time on every badge and
|
| 11 |
+
include rewards on assistant badges. Each message is individually
|
| 12 |
+
hide/show by click; when hidden, only the badge remains and "(...)" is
|
| 13 |
+
shown inline (not inside a bubble).
|
| 14 |
+
"""
|
| 15 |
+
import html
|
| 16 |
+
import re as _re
|
| 17 |
+
|
| 18 |
+
# Prepare ordering: sort by (time_step, original_index) to keep stable order within same step
|
| 19 |
+
indexed_turns = list(enumerate(chat_turns))
|
| 20 |
+
indexed_turns.sort(key=lambda t: (t[1].time_step, t[0]))
|
| 21 |
+
assistant_agents = sorted({t.agent_id for t in chat_turns if t.role == "assistant"})
|
| 22 |
+
enable_split_view = len(assistant_agents) == 2
|
| 23 |
+
|
| 24 |
+
# CSS styles (simplified layout; no time-step or agent-column backgrounds)
|
| 25 |
+
css = """
|
| 26 |
+
<style>
|
| 27 |
+
:root {
|
| 28 |
+
--font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif;
|
| 29 |
+
--bg: #ffffff;
|
| 30 |
+
--text: #1c0b00;
|
| 31 |
+
--muted-text: #2C3E50;
|
| 32 |
+
--accent-muted: #BDC3C7;
|
| 33 |
+
--accent-muted-2: #D0D7DE;
|
| 34 |
+
--panel-bg: #F8FAFC;
|
| 35 |
+
--reward-color: #3a2e00; /* dark text for reward pill */
|
| 36 |
+
--font-size: 14px;
|
| 37 |
+
--border-width: 2px;
|
| 38 |
+
--corner-radius: 6px;
|
| 39 |
+
--pill-radius-left: 999px 0 0 999px;
|
| 40 |
+
--pill-radius-right: 0 999px 999px 0;
|
| 41 |
+
--inset-shadow: 0 1px 0 rgba(0,0,0,0.03) inset;
|
| 42 |
+
|
| 43 |
+
/* Chat View Colors */
|
| 44 |
+
--alice-bg: #dcf8c6;
|
| 45 |
+
--alice-border: #0eb224;
|
| 46 |
+
--bob-bg: #ffe4cc;
|
| 47 |
+
--bob-border: #ef8323;
|
| 48 |
+
--user-bg: #f5f5f5;
|
| 49 |
+
--chat-bg: #ffffff;
|
| 50 |
+
}
|
| 51 |
+
body {
|
| 52 |
+
font-family: var(--font-family);
|
| 53 |
+
margin: 12px;
|
| 54 |
+
background-color: var(--bg);
|
| 55 |
+
color: var(--text);
|
| 56 |
+
font-size: var(--font-size);
|
| 57 |
+
line-height: 1.5;
|
| 58 |
+
}
|
| 59 |
+
|
| 60 |
+
/* Chat View Styles */
|
| 61 |
+
#flow-chat {
|
| 62 |
+
max-width: 900px;
|
| 63 |
+
margin: 0 auto;
|
| 64 |
+
background: var(--chat-bg);
|
| 65 |
+
padding: 12px 16px 12px 8px;
|
| 66 |
+
border-radius: 8px;
|
| 67 |
+
}
|
| 68 |
+
|
| 69 |
+
.simultaneous-messages {
|
| 70 |
+
display: flex !important;
|
| 71 |
+
flex-direction: row !important;
|
| 72 |
+
flex-wrap: nowrap !important;
|
| 73 |
+
gap: 8px;
|
| 74 |
+
margin-bottom: 4px;
|
| 75 |
+
align-items: flex-start;
|
| 76 |
+
width: 100%;
|
| 77 |
+
overflow: hidden;
|
| 78 |
+
box-sizing: border-box;
|
| 79 |
+
}
|
| 80 |
+
|
| 81 |
+
.simultaneous-messages .chat-message {
|
| 82 |
+
flex: 1 1 0 !important;
|
| 83 |
+
margin-bottom: 0 !important;
|
| 84 |
+
display: flex !important;
|
| 85 |
+
flex-direction: row !important;
|
| 86 |
+
align-items: flex-start !important;
|
| 87 |
+
margin-left: 0 !important;
|
| 88 |
+
min-width: 0 !important;
|
| 89 |
+
max-width: 50% !important;
|
| 90 |
+
gap: 0 !important;
|
| 91 |
+
overflow: hidden !important;
|
| 92 |
+
}
|
| 93 |
+
|
| 94 |
+
.simultaneous-messages .chat-message-content {
|
| 95 |
+
max-width: 100% !important;
|
| 96 |
+
width: 100%;
|
| 97 |
+
align-items: flex-start !important;
|
| 98 |
+
margin-left: 0 !important;
|
| 99 |
+
overflow: hidden !important;
|
| 100 |
+
}
|
| 101 |
+
|
| 102 |
+
.simultaneous-messages .chat-message.agent-alice {
|
| 103 |
+
justify-content: flex-start !important;
|
| 104 |
+
}
|
| 105 |
+
|
| 106 |
+
.simultaneous-messages .chat-message.agent-bob {
|
| 107 |
+
justify-content: flex-end !important;
|
| 108 |
+
}
|
| 109 |
+
|
| 110 |
+
.simultaneous-messages .chat-message.agent-alice .chat-message-content {
|
| 111 |
+
margin-left: 0 !important;
|
| 112 |
+
align-items: flex-start !important;
|
| 113 |
+
}
|
| 114 |
+
|
| 115 |
+
.simultaneous-messages .chat-message.agent-bob .chat-message-content {
|
| 116 |
+
margin-left: auto !important;
|
| 117 |
+
margin-right: 0 !important;
|
| 118 |
+
align-items: flex-end !important;
|
| 119 |
+
}
|
| 120 |
+
|
| 121 |
+
.simultaneous-messages .chat-bubble {
|
| 122 |
+
max-width: 100%;
|
| 123 |
+
word-break: break-word;
|
| 124 |
+
overflow-wrap: break-word;
|
| 125 |
+
box-sizing: border-box;
|
| 126 |
+
}
|
| 127 |
+
|
| 128 |
+
.simultaneous-messages .chat-message.agent-alice .chat-bubble {
|
| 129 |
+
border-radius: 10px;
|
| 130 |
+
}
|
| 131 |
+
|
| 132 |
+
.simultaneous-messages .chat-message.agent-bob .chat-bubble {
|
| 133 |
+
border-radius: 10px;
|
| 134 |
+
}
|
| 135 |
+
|
| 136 |
+
.simultaneous-messages .chat-message.agent-alice .chat-header {
|
| 137 |
+
justify-content: flex-start;
|
| 138 |
+
flex-shrink: 0;
|
| 139 |
+
}
|
| 140 |
+
|
| 141 |
+
.simultaneous-messages .chat-message.agent-bob .chat-header {
|
| 142 |
+
justify-content: flex-end;
|
| 143 |
+
flex-shrink: 0;
|
| 144 |
+
}
|
| 145 |
+
|
| 146 |
+
.simultaneous-messages .chat-reasoning {
|
| 147 |
+
max-width: 100%;
|
| 148 |
+
overflow-wrap: break-word;
|
| 149 |
+
}
|
| 150 |
+
|
| 151 |
+
.chat-message {
|
| 152 |
+
display: flex;
|
| 153 |
+
margin-bottom: 2px;
|
| 154 |
+
align-items: flex-end;
|
| 155 |
+
gap: 6px;
|
| 156 |
+
position: relative;
|
| 157 |
+
margin-left: 36px;
|
| 158 |
+
}
|
| 159 |
+
|
| 160 |
+
.chat-message.agent-alice {
|
| 161 |
+
margin-left: 0;
|
| 162 |
+
}
|
| 163 |
+
|
| 164 |
+
.chat-message.agent-alice::before {
|
| 165 |
+
left: 0;
|
| 166 |
+
}
|
| 167 |
+
|
| 168 |
+
.chat-message.role-user {
|
| 169 |
+
opacity: 0.7;
|
| 170 |
+
}
|
| 171 |
+
|
| 172 |
+
.chat-message::before {
|
| 173 |
+
content: '';
|
| 174 |
+
position: absolute;
|
| 175 |
+
left: -36px;
|
| 176 |
+
top: 0;
|
| 177 |
+
bottom: 0;
|
| 178 |
+
width: 36px;
|
| 179 |
+
pointer-events: auto;
|
| 180 |
+
}
|
| 181 |
+
|
| 182 |
+
.merge-btn {
|
| 183 |
+
position: absolute;
|
| 184 |
+
left: -30px;
|
| 185 |
+
top: 50%;
|
| 186 |
+
transform: translateY(-50%);
|
| 187 |
+
width: 26px;
|
| 188 |
+
height: 26px;
|
| 189 |
+
border-radius: 4px;
|
| 190 |
+
border: 1.5px solid var(--accent-muted);
|
| 191 |
+
background: white;
|
| 192 |
+
cursor: pointer;
|
| 193 |
+
font-size: var(--font-size);
|
| 194 |
+
opacity: 0;
|
| 195 |
+
display: flex;
|
| 196 |
+
align-items: center;
|
| 197 |
+
justify-content: center;
|
| 198 |
+
transition: opacity 0.2s ease, transform 0.1s ease;
|
| 199 |
+
padding: 0;
|
| 200 |
+
line-height: 1;
|
| 201 |
+
z-index: 10;
|
| 202 |
+
}
|
| 203 |
+
|
| 204 |
+
.chat-message:hover .merge-btn,
|
| 205 |
+
.merge-btn:hover {
|
| 206 |
+
opacity: 1;
|
| 207 |
+
}
|
| 208 |
+
|
| 209 |
+
.merge-btn:hover {
|
| 210 |
+
background: var(--panel-bg);
|
| 211 |
+
border-color: var(--accent-muted-2);
|
| 212 |
+
transform: translateY(-50%) scale(1.15);
|
| 213 |
+
box-shadow: 0 2px 4px rgba(0, 0, 0, 0.15);
|
| 214 |
+
}
|
| 215 |
+
|
| 216 |
+
.merge-btn:active {
|
| 217 |
+
transform: translateY(-50%) scale(0.95);
|
| 218 |
+
}
|
| 219 |
+
|
| 220 |
+
.chat-message.agent-alice .merge-btn {
|
| 221 |
+
left: -30px;
|
| 222 |
+
}
|
| 223 |
+
|
| 224 |
+
.chat-message.role-user .merge-btn {
|
| 225 |
+
display: none !important;
|
| 226 |
+
}
|
| 227 |
+
|
| 228 |
+
.simultaneous-messages .merge-btn {
|
| 229 |
+
opacity: 0 !important;
|
| 230 |
+
pointer-events: none;
|
| 231 |
+
}
|
| 232 |
+
|
| 233 |
+
.simultaneous-messages {
|
| 234 |
+
padding: 6px 0 6px 0 !important;
|
| 235 |
+
margin-left: 0 !important;
|
| 236 |
+
margin-right: 0 !important;
|
| 237 |
+
position: relative !important;
|
| 238 |
+
background: transparent !important;
|
| 239 |
+
border-radius: 0 !important;
|
| 240 |
+
box-sizing: border-box !important;
|
| 241 |
+
overflow: visible !important;
|
| 242 |
+
max-width: 100% !important;
|
| 243 |
+
border: none !important;
|
| 244 |
+
transition: padding 0.2s ease !important;
|
| 245 |
+
}
|
| 246 |
+
|
| 247 |
+
.simultaneous-messages:hover {
|
| 248 |
+
padding-top: 40px !important;
|
| 249 |
+
}
|
| 250 |
+
|
| 251 |
+
.simultaneous-messages::before {
|
| 252 |
+
content: '⇅ Merged';
|
| 253 |
+
position: absolute;
|
| 254 |
+
left: 0 !important;
|
| 255 |
+
top: 8px !important;
|
| 256 |
+
font-size: var(--font-size);
|
| 257 |
+
font-weight: 500;
|
| 258 |
+
color: #888;
|
| 259 |
+
pointer-events: none;
|
| 260 |
+
opacity: 0;
|
| 261 |
+
transition: opacity 0.2s ease;
|
| 262 |
+
}
|
| 263 |
+
|
| 264 |
+
.simultaneous-messages:hover::before {
|
| 265 |
+
opacity: 1;
|
| 266 |
+
}
|
| 267 |
+
|
| 268 |
+
.unmerge-btn {
|
| 269 |
+
position: absolute !important;
|
| 270 |
+
right: 0 !important;
|
| 271 |
+
top: 6px !important;
|
| 272 |
+
width: 36px !important;
|
| 273 |
+
height: 28px !important;
|
| 274 |
+
border-radius: 5px !important;
|
| 275 |
+
border: 2px solid #d63031 !important;
|
| 276 |
+
background: white !important;
|
| 277 |
+
cursor: pointer !important;
|
| 278 |
+
font-size: var(--font-size) !important;
|
| 279 |
+
font-weight: bold !important;
|
| 280 |
+
color: #d63031 !important;
|
| 281 |
+
display: flex !important;
|
| 282 |
+
align-items: center !important;
|
| 283 |
+
justify-content: center !important;
|
| 284 |
+
transition: all 0.2s ease !important;
|
| 285 |
+
padding: 0 !important;
|
| 286 |
+
line-height: 1 !important;
|
| 287 |
+
z-index: 1000 !important;
|
| 288 |
+
flex: none !important;
|
| 289 |
+
pointer-events: auto !important;
|
| 290 |
+
box-shadow: 0 2px 6px rgba(214, 48, 49, 0.3) !important;
|
| 291 |
+
opacity: 0 !important;
|
| 292 |
+
}
|
| 293 |
+
|
| 294 |
+
.simultaneous-messages:hover .unmerge-btn {
|
| 295 |
+
opacity: 1 !important;
|
| 296 |
+
}
|
| 297 |
+
|
| 298 |
+
.unmerge-btn:hover {
|
| 299 |
+
background: #ffe5e5 !important;
|
| 300 |
+
border-color: #b71c1c !important;
|
| 301 |
+
transform: scale(1.1) !important;
|
| 302 |
+
box-shadow: 0 3px 8px rgba(214, 48, 49, 0.4) !important;
|
| 303 |
+
}
|
| 304 |
+
|
| 305 |
+
.unmerge-btn:active {
|
| 306 |
+
transform: scale(0.95) !important;
|
| 307 |
+
background: #ffcccc !important;
|
| 308 |
+
}
|
| 309 |
+
|
| 310 |
+
.chat-message-content {
|
| 311 |
+
max-width: 72%;
|
| 312 |
+
display: flex;
|
| 313 |
+
flex-direction: column;
|
| 314 |
+
gap: 2px;
|
| 315 |
+
}
|
| 316 |
+
|
| 317 |
+
.chat-message.agent-alice .chat-message-content {
|
| 318 |
+
align-items: flex-start;
|
| 319 |
+
}
|
| 320 |
+
|
| 321 |
+
.chat-message.agent-bob .chat-message-content {
|
| 322 |
+
align-items: flex-end;
|
| 323 |
+
margin-left: auto;
|
| 324 |
+
}
|
| 325 |
+
|
| 326 |
+
.chat-bubble {
|
| 327 |
+
padding: 6px 10px;
|
| 328 |
+
border-radius: 10px;
|
| 329 |
+
word-wrap: break-word;
|
| 330 |
+
position: relative;
|
| 331 |
+
box-shadow: 0 1px 3px rgba(0, 0, 0, 0.1);
|
| 332 |
+
line-height: 1.4;
|
| 333 |
+
}
|
| 334 |
+
|
| 335 |
+
.chat-message.agent-alice .chat-bubble {
|
| 336 |
+
background: var(--alice-bg);
|
| 337 |
+
border: 2px solid var(--alice-border);
|
| 338 |
+
border-radius: 10px 10px 10px 2px;
|
| 339 |
+
}
|
| 340 |
+
|
| 341 |
+
.chat-message.agent-bob .chat-bubble {
|
| 342 |
+
background: var(--bob-bg);
|
| 343 |
+
border: 2px solid var(--bob-border);
|
| 344 |
+
border-radius: 10px 10px 2px 10px;
|
| 345 |
+
}
|
| 346 |
+
|
| 347 |
+
.chat-message.role-user .chat-bubble {
|
| 348 |
+
background: var(--user-bg);
|
| 349 |
+
border: 2px solid #d0d0d0;
|
| 350 |
+
}
|
| 351 |
+
|
| 352 |
+
.chat-header {
|
| 353 |
+
display: flex;
|
| 354 |
+
align-items: center;
|
| 355 |
+
gap: 4px;
|
| 356 |
+
margin-bottom: 2px;
|
| 357 |
+
font-size: var(--font-size);
|
| 358 |
+
font-weight: 600;
|
| 359 |
+
line-height: 1.2;
|
| 360 |
+
}
|
| 361 |
+
|
| 362 |
+
.chat-message.agent-alice .chat-header {
|
| 363 |
+
color: var(--alice-border);
|
| 364 |
+
}
|
| 365 |
+
|
| 366 |
+
.chat-message.agent-bob .chat-header {
|
| 367 |
+
color: var(--bob-border);
|
| 368 |
+
}
|
| 369 |
+
|
| 370 |
+
.chat-timestamp {
|
| 371 |
+
font-size: var(--font-size);
|
| 372 |
+
color: var(--muted-text);
|
| 373 |
+
margin-top: 1px;
|
| 374 |
+
opacity: 0.75;
|
| 375 |
+
}
|
| 376 |
+
|
| 377 |
+
.chat-reward {
|
| 378 |
+
display: inline-flex;
|
| 379 |
+
align-items: center;
|
| 380 |
+
background: linear-gradient(90deg, #fffdf2 0%, #ffffff 75%);
|
| 381 |
+
color: #000000;
|
| 382 |
+
font-weight: 600;
|
| 383 |
+
font-size: var(--font-size);
|
| 384 |
+
padding: 1px 5px;
|
| 385 |
+
border-radius: 3px;
|
| 386 |
+
border: 1px solid #f4e6a8;
|
| 387 |
+
margin-left: 4px;
|
| 388 |
+
line-height: 1.3;
|
| 389 |
+
}
|
| 390 |
+
|
| 391 |
+
.chat-reasoning {
|
| 392 |
+
font-size: var(--font-size);
|
| 393 |
+
font-style: italic;
|
| 394 |
+
color: #555;
|
| 395 |
+
margin-bottom: 2px;
|
| 396 |
+
padding: 4px 8px;
|
| 397 |
+
background: rgba(0, 0, 0, 0.03);
|
| 398 |
+
border-radius: 5px;
|
| 399 |
+
cursor: pointer;
|
| 400 |
+
line-height: 1.3;
|
| 401 |
+
}
|
| 402 |
+
|
| 403 |
+
.chat-reasoning.collapsed .reasoning-text {
|
| 404 |
+
display: none;
|
| 405 |
+
}
|
| 406 |
+
|
| 407 |
+
.chat-reasoning.collapsed::after {
|
| 408 |
+
content: ' (click to expand)';
|
| 409 |
+
color: #777;
|
| 410 |
+
}
|
| 411 |
+
|
| 412 |
+
.chat-group-divider {
|
| 413 |
+
display: flex;
|
| 414 |
+
align-items: center;
|
| 415 |
+
gap: 8px;
|
| 416 |
+
width: 100%;
|
| 417 |
+
margin: 8px 0 4px 0;
|
| 418 |
+
position: relative;
|
| 419 |
+
cursor: pointer;
|
| 420 |
+
user-select: none;
|
| 421 |
+
}
|
| 422 |
+
|
| 423 |
+
.chat-group-divider::before,
|
| 424 |
+
.chat-group-divider::after {
|
| 425 |
+
content: "";
|
| 426 |
+
flex: 1 1 auto;
|
| 427 |
+
height: 2px;
|
| 428 |
+
background: linear-gradient(90deg, rgba(224,230,235,0), var(--accent-muted-2) 30%, var(--accent-muted-2) 70%, rgba(224,230,235,0));
|
| 429 |
+
}
|
| 430 |
+
|
| 431 |
+
.chat-group-label {
|
| 432 |
+
display: inline-block;
|
| 433 |
+
background: white;
|
| 434 |
+
padding: 2px 12px;
|
| 435 |
+
border-radius: 999px;
|
| 436 |
+
font-size: var(--font-size);
|
| 437 |
+
font-weight: 700;
|
| 438 |
+
color: var(--muted-text);
|
| 439 |
+
border: 1.5px solid var(--accent-muted);
|
| 440 |
+
box-shadow: 0 1px 3px rgba(0, 0, 0, 0.08);
|
| 441 |
+
line-height: 1.4;
|
| 442 |
+
position: relative;
|
| 443 |
+
transition: background 0.2s ease;
|
| 444 |
+
}
|
| 445 |
+
|
| 446 |
+
.chat-group-divider:hover .chat-group-label {
|
| 447 |
+
background: var(--panel-bg);
|
| 448 |
+
}
|
| 449 |
+
|
| 450 |
+
.chat-group-label::before {
|
| 451 |
+
content: '▼ ';
|
| 452 |
+
font-size: 0.8em;
|
| 453 |
+
display: inline-block;
|
| 454 |
+
transition: transform 0.2s ease;
|
| 455 |
+
opacity: 0;
|
| 456 |
+
}
|
| 457 |
+
|
| 458 |
+
.chat-group-divider:hover .chat-group-label::before {
|
| 459 |
+
opacity: 1;
|
| 460 |
+
}
|
| 461 |
+
|
| 462 |
+
.chat-group-divider.collapsed .chat-group-label::before {
|
| 463 |
+
content: '▶ ';
|
| 464 |
+
opacity: 1;
|
| 465 |
+
}
|
| 466 |
+
|
| 467 |
+
.chat-group-divider.collapsed + * {
|
| 468 |
+
display: none !important;
|
| 469 |
+
}
|
| 470 |
+
|
| 471 |
+
/* Hide collapsed rounds in strong hide mode */
|
| 472 |
+
.strong-hide .chat-group-divider.collapsed {
|
| 473 |
+
display: none !important;
|
| 474 |
+
}
|
| 475 |
+
|
| 476 |
+
/* Chat view width control */
|
| 477 |
+
#flow-chat {
|
| 478 |
+
--chat-width: 900px;
|
| 479 |
+
max-width: var(--chat-width);
|
| 480 |
+
margin: 0 auto;
|
| 481 |
+
}
|
| 482 |
+
|
| 483 |
+
/* Hide user messages when toggle is on */
|
| 484 |
+
#flow-chat.hide-user-messages .chat-message.role-user {
|
| 485 |
+
display: none;
|
| 486 |
+
}
|
| 487 |
+
|
| 488 |
+
/* Hide rewards when hiding user messages */
|
| 489 |
+
#flow-chat.hide-user-messages .chat-reward {
|
| 490 |
+
display: none;
|
| 491 |
+
}
|
| 492 |
+
|
| 493 |
+
/* Round context annotations */
|
| 494 |
+
.round-context {
|
| 495 |
+
text-align: center;
|
| 496 |
+
margin: 4px auto;
|
| 497 |
+
max-width: 100%;
|
| 498 |
+
}
|
| 499 |
+
|
| 500 |
+
.round-context-edit {
|
| 501 |
+
min-height: 20px;
|
| 502 |
+
padding: 5px 10px;
|
| 503 |
+
border: 1.5px dashed var(--accent-muted);
|
| 504 |
+
border-radius: 6px;
|
| 505 |
+
background: #fafafa;
|
| 506 |
+
cursor: text;
|
| 507 |
+
transition: all 0.2s ease;
|
| 508 |
+
outline: none;
|
| 509 |
+
font-size: var(--font-size);
|
| 510 |
+
line-height: 1.3;
|
| 511 |
+
user-select: text;
|
| 512 |
+
-webkit-user-select: text;
|
| 513 |
+
-moz-user-select: text;
|
| 514 |
+
-ms-user-select: text;
|
| 515 |
+
}
|
| 516 |
+
|
| 517 |
+
.round-context-edit:focus {
|
| 518 |
+
border-style: solid;
|
| 519 |
+
border-color: var(--accent-muted-2);
|
| 520 |
+
background: #ffffff;
|
| 521 |
+
box-shadow: 0 2px 8px rgba(0, 0, 0, 0.1);
|
| 522 |
+
}
|
| 523 |
+
|
| 524 |
+
.round-context-edit:empty:before {
|
| 525 |
+
content: attr(data-placeholder);
|
| 526 |
+
color: #999;
|
| 527 |
+
font-style: italic;
|
| 528 |
+
}
|
| 529 |
+
|
| 530 |
+
.round-context-controls {
|
| 531 |
+
display: none;
|
| 532 |
+
justify-content: center;
|
| 533 |
+
gap: 4px;
|
| 534 |
+
margin-top: 4px;
|
| 535 |
+
flex-wrap: wrap;
|
| 536 |
+
}
|
| 537 |
+
|
| 538 |
+
.round-context-edit:focus + .round-context-controls,
|
| 539 |
+
.round-context-controls:hover,
|
| 540 |
+
.round-context:focus-within .round-context-controls {
|
| 541 |
+
display: flex;
|
| 542 |
+
}
|
| 543 |
+
|
| 544 |
+
.context-color-btn {
|
| 545 |
+
width: 22px;
|
| 546 |
+
height: 22px;
|
| 547 |
+
border-radius: 50%;
|
| 548 |
+
border: 1.5px solid #fff;
|
| 549 |
+
box-shadow: 0 1px 2px rgba(0, 0, 0, 0.15);
|
| 550 |
+
cursor: pointer;
|
| 551 |
+
transition: transform 0.1s ease;
|
| 552 |
+
}
|
| 553 |
+
|
| 554 |
+
.context-color-btn:hover {
|
| 555 |
+
transform: scale(1.15);
|
| 556 |
+
}
|
| 557 |
+
|
| 558 |
+
.context-color-btn:active {
|
| 559 |
+
transform: scale(0.95);
|
| 560 |
+
}
|
| 561 |
+
|
| 562 |
+
/* Split agent context boxes */
|
| 563 |
+
.split-agent-context {
|
| 564 |
+
display: flex;
|
| 565 |
+
gap: 6px;
|
| 566 |
+
margin: 4px auto;
|
| 567 |
+
max-width: 100%;
|
| 568 |
+
align-items: flex-start;
|
| 569 |
+
}
|
| 570 |
+
|
| 571 |
+
.agent-context-box {
|
| 572 |
+
flex: 1;
|
| 573 |
+
min-width: 0;
|
| 574 |
+
position: relative;
|
| 575 |
+
}
|
| 576 |
+
|
| 577 |
+
.agent-context-box .round-context-edit {
|
| 578 |
+
margin: 0;
|
| 579 |
+
border-radius: 6px;
|
| 580 |
+
padding: 4px 8px;
|
| 581 |
+
min-height: 18px;
|
| 582 |
+
}
|
| 583 |
+
|
| 584 |
+
.agent-context-box.agent-alice .round-context-edit {
|
| 585 |
+
border-color: var(--alice-border);
|
| 586 |
+
background: rgba(14, 178, 36, 0.03);
|
| 587 |
+
}
|
| 588 |
+
|
| 589 |
+
.agent-context-box.agent-bob .round-context-edit {
|
| 590 |
+
border-color: var(--bob-border);
|
| 591 |
+
background: rgba(239, 131, 35, 0.03);
|
| 592 |
+
}
|
| 593 |
+
|
| 594 |
+
.agent-context-box.agent-alice .round-context-edit:focus {
|
| 595 |
+
border-color: var(--alice-border);
|
| 596 |
+
box-shadow: 0 2px 8px rgba(14, 178, 36, 0.2);
|
| 597 |
+
background: rgba(14, 178, 36, 0.05);
|
| 598 |
+
}
|
| 599 |
+
|
| 600 |
+
.agent-context-box.agent-bob .round-context-edit:focus {
|
| 601 |
+
border-color: var(--bob-border);
|
| 602 |
+
box-shadow: 0 2px 8px rgba(239, 131, 35, 0.2);
|
| 603 |
+
background: rgba(239, 131, 35, 0.05);
|
| 604 |
+
}
|
| 605 |
+
|
| 606 |
+
.agent-context-box .round-context-edit::before {
|
| 607 |
+
font-weight: 700;
|
| 608 |
+
font-size: var(--font-size);
|
| 609 |
+
margin-right: 5px;
|
| 610 |
+
letter-spacing: 0.2px;
|
| 611 |
+
}
|
| 612 |
+
|
| 613 |
+
.agent-context-box.agent-alice .round-context-edit::before {
|
| 614 |
+
content: 'Alice Prompt Summary:';
|
| 615 |
+
color: var(--alice-border);
|
| 616 |
+
}
|
| 617 |
+
|
| 618 |
+
.agent-context-box.agent-bob .round-context-edit::before {
|
| 619 |
+
content: 'Bob Prompt Summary:';
|
| 620 |
+
color: var(--bob-border);
|
| 621 |
+
}
|
| 622 |
+
|
| 623 |
+
/* Empty context boxes will be hidden by JavaScript when strong hide is enabled */
|
| 624 |
+
.messages-flow { display: block; }
|
| 625 |
+
.split-wrapper { display: flex; gap: 4px; align-items: flex-start; position: relative; }
|
| 626 |
+
.split-col { flex:1 1 0; min-width:0; }
|
| 627 |
+
/* In split view keep same inline density as linear view */
|
| 628 |
+
.split-col .chat-turn { display: inline; }
|
| 629 |
+
.split-wrapper.resizing { user-select: none; }
|
| 630 |
+
.split-resizer { width:4px; cursor: col-resize; flex:0 0 auto; align-self: stretch; position: relative; background: linear-gradient(90deg, rgba(224,230,235,0), var(--accent-muted-2) 30%, var(--accent-muted-2) 70%, rgba(224,230,235,0)); border-radius:2px; transition: background .15s ease, width .15s ease; }
|
| 631 |
+
.split-resizer:hover { background: linear-gradient(90deg, rgba(224,230,235,0), var(--accent-muted) 35%, var(--accent-muted) 65%, rgba(224,230,235,0)); }
|
| 632 |
+
.split-resizer.dragging { background: linear-gradient(90deg, rgba(224,230,235,0), var(--accent-muted) 25%, var(--accent-muted) 75%, rgba(224,230,235,0)); }
|
| 633 |
+
/* Inline reasoning (removed toggle to prevent layout shift on click) */
|
| 634 |
+
.reasoning-inline { display:inline; font-size:var(--font-size); font-style:italic; color:#555; white-space:pre-wrap; margin-right:4px; cursor:pointer; position:relative; }
|
| 635 |
+
.reasoning-inline .reasoning-text { display:inline; }
|
| 636 |
+
.reasoning-inline .reasoning-icon { display:inline-block; margin-right:2px; }
|
| 637 |
+
.reasoning-inline.collapsed .reasoning-text { display:none; }
|
| 638 |
+
.reasoning-inline.collapsed::after { content:'(...)'; font-style:italic; color:#777; margin-left:4px; }
|
| 639 |
+
.message-box .main-content { white-space:normal; }
|
| 640 |
+
/* tighten spacing */
|
| 641 |
+
.split-col .group-divider { margin:4px 0 2px 0; }
|
| 642 |
+
.toolbar {
|
| 643 |
+
display: flex;
|
| 644 |
+
align-items: center;
|
| 645 |
+
gap: 8px;
|
| 646 |
+
margin-bottom: 0;
|
| 647 |
+
font-size: var(--font-size);
|
| 648 |
+
max-height: 0;
|
| 649 |
+
overflow: hidden;
|
| 650 |
+
opacity: 0;
|
| 651 |
+
pointer-events: none;
|
| 652 |
+
transition: max-height 0.2s ease, opacity 0.2s ease;
|
| 653 |
+
flex-wrap: wrap;
|
| 654 |
+
}
|
| 655 |
+
.toolbar-wrap { position: sticky; top: 0; z-index: 10; background: var(--bg); }
|
| 656 |
+
.toolbar-hotzone { height: 6px; }
|
| 657 |
+
.toolbar-wrap:hover .toolbar { max-height: 500px; opacity: 1; pointer-events: auto; margin-bottom: 12px; }
|
| 658 |
+
.toolbar * { pointer-events: auto !important; }
|
| 659 |
+
.toolbar input,
|
| 660 |
+
.toolbar select { z-index: 100 !important; position: relative; }
|
| 661 |
+
.toolbar input[type="number"],
|
| 662 |
+
.toolbar input[type="text"],
|
| 663 |
+
.toolbar select {
|
| 664 |
+
width: 72px;
|
| 665 |
+
padding: 2px 6px;
|
| 666 |
+
border: 1px solid var(--accent-muted);
|
| 667 |
+
border-radius: var(--corner-radius);
|
| 668 |
+
background: var(--bg);
|
| 669 |
+
user-select: text !important;
|
| 670 |
+
-webkit-user-select: text !important;
|
| 671 |
+
-moz-user-select: text !important;
|
| 672 |
+
-ms-user-select: text !important;
|
| 673 |
+
pointer-events: auto !important;
|
| 674 |
+
cursor: pointer !important;
|
| 675 |
+
}
|
| 676 |
+
.toolbar input[type="text"] {
|
| 677 |
+
cursor: text !important;
|
| 678 |
+
}
|
| 679 |
+
.toolbar input[type="text"]:focus,
|
| 680 |
+
.toolbar input[type="number"]:focus,
|
| 681 |
+
.toolbar select:focus {
|
| 682 |
+
outline: 2px solid #0066cc;
|
| 683 |
+
outline-offset: 1px;
|
| 684 |
+
}
|
| 685 |
+
.toolbar button {
|
| 686 |
+
padding: 4px 8px;
|
| 687 |
+
border: 1px solid var(--accent-muted);
|
| 688 |
+
background: var(--panel-bg);
|
| 689 |
+
border-radius: var(--corner-radius);
|
| 690 |
+
cursor: pointer;
|
| 691 |
+
}
|
| 692 |
+
.chat-turn {
|
| 693 |
+
display: inline; /* inline like text */
|
| 694 |
+
background: transparent;
|
| 695 |
+
position: relative;
|
| 696 |
+
cursor: pointer;
|
| 697 |
+
}
|
| 698 |
+
/* No agent-specific background distinctions */
|
| 699 |
+
.turn-content {
|
| 700 |
+
white-space: normal;
|
| 701 |
+
color: var(--text);
|
| 702 |
+
font-size: var(--font-size);
|
| 703 |
+
display: inline; /* inline flow */
|
| 704 |
+
}
|
| 705 |
+
.chat-turn .agent-badge { margin-right: 0; vertical-align: baseline; }
|
| 706 |
+
.agent-badge {
|
| 707 |
+
display: inline;
|
| 708 |
+
position: relative;
|
| 709 |
+
border: var(--border-width) solid var(--accent-muted); /* slightly thicker */
|
| 710 |
+
border-radius: var(--pill-radius-left); /* round left and bottom-right */
|
| 711 |
+
font-size: var(--font-size);
|
| 712 |
+
color: var(--muted-text);
|
| 713 |
+
background: var(--panel-bg);
|
| 714 |
+
box-shadow: var(--inset-shadow);
|
| 715 |
+
line-height: 1.2;
|
| 716 |
+
border-right: 0;
|
| 717 |
+
}
|
| 718 |
+
/* Use flex on assistant badges to vertically center reward pill */
|
| 719 |
+
.chat-turn.role-assistant .agent-badge { display: inline-flex; align-items: center; }
|
| 720 |
+
.agent-badge::after {
|
| 721 |
+
content: none;
|
| 722 |
+
}
|
| 723 |
+
/* removed external separator; emoji is rendered inside message bubble */
|
| 724 |
+
.agent-name { font-weight: 700; }
|
| 725 |
+
.emoji-bw { filter: grayscale(100%); opacity: 0.95; font-size: var(--font-size); vertical-align: baseline; margin: 0; position: relative; top: -1px; line-height: 1; display: inline-block; }
|
| 726 |
+
.ts-badge {
|
| 727 |
+
position: relative;
|
| 728 |
+
display: inline;
|
| 729 |
+
border: var(--border-width) solid var(--accent-muted-2); /* slightly thicker */
|
| 730 |
+
border-radius: var(--corner-radius); /* not a pill */
|
| 731 |
+
font-size: var(--font-size);
|
| 732 |
+
# font-weight: 700;
|
| 733 |
+
color: var(--muted-text);
|
| 734 |
+
background: #F4F8FB; /* subtle tint */
|
| 735 |
+
# padding: 1px 6px; /* slight padding for visibility */
|
| 736 |
+
margin-right: 8px; /* small gap from following content */
|
| 737 |
+
pointer-events: auto; /* allow events so we can ignore them in JS */
|
| 738 |
+
}
|
| 739 |
+
/* Hide timestep badges when grouping by 1 */
|
| 740 |
+
.hide-ts-badges .ts-badge { display: none; }
|
| 741 |
+
/* Strong hide: completely hide collapsed turns */
|
| 742 |
+
.strong-hide .chat-turn.collapsed { display: none; }
|
| 743 |
+
.ts-badge::before {
|
| 744 |
+
content: "";
|
| 745 |
+
position: relative;
|
| 746 |
+
background: var(--accent-muted-2);
|
| 747 |
+
border-radius: 2px;
|
| 748 |
+
}
|
| 749 |
+
.agent-badge { margin-left: 6px; }
|
| 750 |
+
.message-box {
|
| 751 |
+
display: inline; /* inline bubble behaving like text */
|
| 752 |
+
font-size: var(--font-size);
|
| 753 |
+
border: var(--border-width) solid var(--accent-muted);
|
| 754 |
+
border-radius: var(--pill-radius-right); /* round left and bottom-right */
|
| 755 |
+
position: relative;
|
| 756 |
+
background: var(--bg);
|
| 757 |
+
vertical-align: baseline;
|
| 758 |
+
line-height: 1.2;
|
| 759 |
+
padding-left: 0;
|
| 760 |
+
border-left: 0;
|
| 761 |
+
}
|
| 762 |
+
.chat-turn.agent-alice.role-assistant .message-box::before { color: #0eb224; }
|
| 763 |
+
.chat-turn.agent-bob.role-assistant .message-box::before { color: #ef8323; }
|
| 764 |
+
.chat-turn.collapsed .message-box::before { display: none; }
|
| 765 |
+
/* Assistant bubble border colors by common agent names */
|
| 766 |
+
.chat-turn.agent-alice.role-assistant .message-box { border-color: #0eb224; }
|
| 767 |
+
.chat-turn.agent-bob.role-assistant .message-box { border-color: #ef8323; }
|
| 768 |
+
/* Tie badge and seam to agent color for a cohesive capsule, assistants only */
|
| 769 |
+
.chat-turn.agent-alice.role-assistant .agent-badge { border-color: #0eb224; background: rgba(14,178,36,0.08); }
|
| 770 |
+
.chat-turn.agent-alice.role-assistant .agent-badge::after { border-right-color: #0eb224; }
|
| 771 |
+
.chat-turn.agent-alice.role-assistant .turn-content::before { border-left-color: #0eb224; border-top-color: #0eb224; }
|
| 772 |
+
.chat-turn.agent-alice.role-assistant .message-box { border-color: #0eb224; }
|
| 773 |
+
|
| 774 |
+
.chat-turn.agent-bob.role-assistant .agent-badge { border-color: #ef8323; background: rgba(239,131,35,0.10); }
|
| 775 |
+
.chat-turn.agent-bob.role-assistant .agent-badge::after { border-right-color: #ef8323; }
|
| 776 |
+
.chat-turn.agent-bob.role-assistant .turn-content::before { border-left-color: #ef8323; border-top-color: #ef8323; }
|
| 777 |
+
.chat-turn.agent-bob.role-assistant .message-box { border-color: #ef8323; }
|
| 778 |
+
/* No colored agent-name; keep neutral */
|
| 779 |
+
.reward {
|
| 780 |
+
display: inline-flex;
|
| 781 |
+
align-items: center;
|
| 782 |
+
justify-content: center;
|
| 783 |
+
background: linear-gradient(90deg, #fffdf2 0%, #ffffff 75%);
|
| 784 |
+
color: #000000; /* full black */
|
| 785 |
+
font-weight: 600; /* slightly bolder */
|
| 786 |
+
font-family: "Inter", ui-sans-serif, system-ui, -apple-system, BlinkMacSystemFont, "Segoe UI", Roboto, Oxygen, Ubuntu, Cantarell, "Fira Sans", "Droid Sans", "Helvetica Neue", Arial, "Noto Sans", sans-serif;
|
| 787 |
+
font-size: var(--font-size);
|
| 788 |
+
letter-spacing: 0.15px;
|
| 789 |
+
line-height: 1;
|
| 790 |
+
padding: 0 4px 1px 4px; /* slight bottom pad for optical centering */
|
| 791 |
+
border-radius: 4px;
|
| 792 |
+
border: 1px solid #f4e6a8;
|
| 793 |
+
margin: 0 4px;
|
| 794 |
+
box-shadow: 0 0 0 1px rgba(255,255,255,0.55) inset, 0 1px 2px rgba(0,0,0,0.04);
|
| 795 |
+
}
|
| 796 |
+
.message-placeholder { display: none; color: #7f8c8d; font-style: italic; }
|
| 797 |
+
.chat-turn.collapsed .message-box { color: transparent; font-size: 0; display: inline-block; }
|
| 798 |
+
.chat-turn.collapsed .message-box::after { content: "(...)"; color: #7f8c8d; font-style: italic; font-size: var(--font-size); line-height: 1.2; }
|
| 799 |
+
.chat-turn.collapsed .agent-badge,
|
| 800 |
+
.chat-turn.collapsed .message-box { opacity: 0.3; }
|
| 801 |
+
/* Group divider - clearer and pretty */
|
| 802 |
+
.group-divider {
|
| 803 |
+
display: flex;
|
| 804 |
+
align-items: center;
|
| 805 |
+
gap: 8px;
|
| 806 |
+
width: 100%;
|
| 807 |
+
margin: 8px 0 4px 0;
|
| 808 |
+
position: relative;
|
| 809 |
+
cursor: pointer;
|
| 810 |
+
user-select: none;
|
| 811 |
+
}
|
| 812 |
+
.group-divider::before,
|
| 813 |
+
.group-divider::after {
|
| 814 |
+
content: "";
|
| 815 |
+
flex: 1 1 auto;
|
| 816 |
+
height: 2px;
|
| 817 |
+
background: linear-gradient(90deg, rgba(224,230,235,0), var(--accent-muted-2) 30%, var(--accent-muted-2) 70%, rgba(224,230,235,0));
|
| 818 |
+
}
|
| 819 |
+
.group-divider .group-label {
|
| 820 |
+
display: inline-block;
|
| 821 |
+
border: 1px solid var(--accent-muted);
|
| 822 |
+
border-radius: 999px;
|
| 823 |
+
padding: 2px 10px;
|
| 824 |
+
font-size: var(--group-label-font-size);
|
| 825 |
+
font-weight: 700;
|
| 826 |
+
color: var(--muted-text);
|
| 827 |
+
background: var(--bg);
|
| 828 |
+
box-shadow: var(--inset-shadow);
|
| 829 |
+
position: relative;
|
| 830 |
+
z-index: 1;
|
| 831 |
+
transition: background 0.2s ease;
|
| 832 |
+
}
|
| 833 |
+
|
| 834 |
+
.group-divider:hover .group-label {
|
| 835 |
+
background: var(--panel-bg);
|
| 836 |
+
}
|
| 837 |
+
|
| 838 |
+
.group-label::before {
|
| 839 |
+
content: '▼ ';
|
| 840 |
+
font-size: 0.8em;
|
| 841 |
+
display: inline-block;
|
| 842 |
+
transition: transform 0.2s ease;
|
| 843 |
+
opacity: 0;
|
| 844 |
+
}
|
| 845 |
+
|
| 846 |
+
.group-divider:hover .group-label::before {
|
| 847 |
+
opacity: 1;
|
| 848 |
+
}
|
| 849 |
+
|
| 850 |
+
.group-divider.collapsed .group-label::before {
|
| 851 |
+
content: '▶ ';
|
| 852 |
+
opacity: 1;
|
| 853 |
+
}
|
| 854 |
+
|
| 855 |
+
/* Hide collapsed rounds in strong hide mode */
|
| 856 |
+
.strong-hide .group-divider.collapsed {
|
| 857 |
+
display: none !important;
|
| 858 |
+
}
|
| 859 |
+
/* Enhance contrast for print / export */
|
| 860 |
+
body.split-mode .group-divider::before,
|
| 861 |
+
body.split-mode .group-divider::after {
|
| 862 |
+
background: linear-gradient(90deg, rgba(224,230,235,0), var(--accent-muted) 25%, var(--accent-muted) 75%, rgba(224,230,235,0));
|
| 863 |
+
}
|
| 864 |
+
.chat-turn .turn-content { position: relative; }
|
| 865 |
+
.chat-turn .turn-content::before {
|
| 866 |
+
content: none;
|
| 867 |
+
}
|
| 868 |
+
.chat-turn .agent-badge {
|
| 869 |
+
position: relative;
|
| 870 |
+
}
|
| 871 |
+
/* removed absolute-positioned emoji to prevent overlap */
|
| 872 |
+
</style>
|
| 873 |
+
"""
|
| 874 |
+
|
| 875 |
+
# HTML structure
|
| 876 |
+
html_parts = [
|
| 877 |
+
"<!DOCTYPE html>",
|
| 878 |
+
"<html>",
|
| 879 |
+
"<head>",
|
| 880 |
+
"<meta charset='UTF-8'>",
|
| 881 |
+
"<title>Chat Turns</title>",
|
| 882 |
+
css,
|
| 883 |
+
"<script>\n"
|
| 884 |
+
"document.addEventListener('DOMContentLoaded', function() {\n"
|
| 885 |
+
" const linearFlow = document.getElementById('flow-linear');\n"
|
| 886 |
+
" const splitFlow = document.getElementById('flow-split');\n"
|
| 887 |
+
" const chatFlow = document.getElementById('flow-chat');\n"
|
| 888 |
+
" let splitViewOn = false;\n"
|
| 889 |
+
" let chatViewOn = true;\n"
|
| 890 |
+
" function activeFlows() { return [chatViewOn && chatFlow ? chatFlow : null, splitViewOn && splitFlow ? splitFlow : null, linearFlow].filter(Boolean).filter(f => f.style.display !== 'none'); }\n"
|
| 891 |
+
" // State for range filtering and strong hide\n"
|
| 892 |
+
" let currentRangeStart = null;\n"
|
| 893 |
+
" let currentRangeEnd = null;\n"
|
| 894 |
+
" let strongHideOn = false;\n"
|
| 895 |
+
" document.body.addEventListener('click', function(e){\n"
|
| 896 |
+
" if (e.target.closest('input, textarea, select, button, .round-context-edit, .toolbar')) { return; }\n"
|
| 897 |
+
" if (e.target.closest('.ts-badge')) { return; }\n"
|
| 898 |
+
" const r = e.target.closest('.reasoning-inline'); if (r) { e.stopPropagation(); r.classList.toggle('collapsed'); return; }\n"
|
| 899 |
+
" const turn = e.target.closest('.chat-turn');\n"
|
| 900 |
+
" if (turn) { e.stopPropagation(); turn.classList.toggle('collapsed'); }\n"
|
| 901 |
+
" });\n"
|
| 902 |
+
" // Reasoning handled via <details>, no JS required\n"
|
| 903 |
+
" function applyRangeFilter() {\n"
|
| 904 |
+
" for (const flow of activeFlows()) {\n"
|
| 905 |
+
" const turns = Array.from(flow.querySelectorAll('.chat-turn'));\n"
|
| 906 |
+
" for (const el of turns) {\n"
|
| 907 |
+
" const t = parseInt(el.getAttribute('data-time-step') || '0', 10);\n"
|
| 908 |
+
" const afterStart = (currentRangeStart === null) || (t >= currentRangeStart);\n"
|
| 909 |
+
" const beforeEnd = (currentRangeEnd === null) || (t <= currentRangeEnd);\n"
|
| 910 |
+
" el.style.display = (afterStart && beforeEnd) ? '' : 'none';\n"
|
| 911 |
+
" }\n"
|
| 912 |
+
" const dividers = Array.from(flow.querySelectorAll('.group-divider'));\n"
|
| 913 |
+
" for (const d of dividers) {\n"
|
| 914 |
+
" let anyVisible = false;\n"
|
| 915 |
+
" let el = d.nextElementSibling;\n"
|
| 916 |
+
" while (el && !el.classList.contains('group-divider')) {\n"
|
| 917 |
+
" if (el.classList.contains('chat-turn')) {\n"
|
| 918 |
+
" const disp = getComputedStyle(el).display;\n"
|
| 919 |
+
" if (disp !== 'none') { anyVisible = true; break; }\n"
|
| 920 |
+
" } else if (el.classList.contains('split-wrapper')) {\n"
|
| 921 |
+
" // Search descendants for any visible chat-turn\n"
|
| 922 |
+
" const turns = Array.from(el.querySelectorAll('.chat-turn'));\n"
|
| 923 |
+
" for (const tEl of turns) {\n"
|
| 924 |
+
" const disp2 = getComputedStyle(tEl).display;\n"
|
| 925 |
+
" if (disp2 !== 'none') { anyVisible = true; break; }\n"
|
| 926 |
+
" }\n"
|
| 927 |
+
" if (anyVisible) break;\n"
|
| 928 |
+
" }\n"
|
| 929 |
+
" el = el.nextElementSibling;\n"
|
| 930 |
+
" }\n"
|
| 931 |
+
" d.style.display = anyVisible ? '' : 'none';\n"
|
| 932 |
+
" }\n"
|
| 933 |
+
" }\n"
|
| 934 |
+
" }\n"
|
| 935 |
+
" function applyGrouping(n) {\n"
|
| 936 |
+
" function groupContainer(container, n) {\n"
|
| 937 |
+
" Array.from(container.querySelectorAll(':scope > .group-divider')).forEach(el => el.remove());\n"
|
| 938 |
+
" if (!n || n <= 0) { return; }\n"
|
| 939 |
+
" const turns = Array.from(container.querySelectorAll(':scope > .chat-turn'));\n"
|
| 940 |
+
" if (turns.length === 0) return;\n"
|
| 941 |
+
" const items = Array.from(container.children).filter(el => !el.classList.contains('group-divider'));\n"
|
| 942 |
+
" const frag = document.createDocumentFragment();\n"
|
| 943 |
+
" let lastGroup = -1;\n"
|
| 944 |
+
" for (const el of items) {\n"
|
| 945 |
+
" if (!el.classList.contains('chat-turn')) { frag.appendChild(el); continue; }\n"
|
| 946 |
+
" const t = parseInt(el.getAttribute('data-time-step') || '0', 10);\n"
|
| 947 |
+
" const g = Math.floor(t / n);\n"
|
| 948 |
+
" if (g !== lastGroup) {\n"
|
| 949 |
+
" const div = document.createElement('div');\n"
|
| 950 |
+
" div.className = 'group-divider';\n"
|
| 951 |
+
" const label = document.createElement('span');\n"
|
| 952 |
+
" label.className = 'group-label';\n"
|
| 953 |
+
" const roundIndex = g + 1;\n"
|
| 954 |
+
" label.textContent = `Round ${roundIndex}`;\n"
|
| 955 |
+
" div.appendChild(label);\n"
|
| 956 |
+
" frag.appendChild(div);\n"
|
| 957 |
+
" lastGroup = g;\n"
|
| 958 |
+
" }\n"
|
| 959 |
+
" frag.appendChild(el);\n"
|
| 960 |
+
" }\n"
|
| 961 |
+
" container.innerHTML = '';\n"
|
| 962 |
+
" container.appendChild(frag);\n"
|
| 963 |
+
" container.classList.toggle('hide-ts-badges', n === 1);\n"
|
| 964 |
+
" container.classList.toggle('strong-hide', strongHideOn);\n"
|
| 965 |
+
" }\n"
|
| 966 |
+
" for (const flow of activeFlows()) {\n"
|
| 967 |
+
" if (flow.id === 'flow-split') {\n"
|
| 968 |
+
" // Snapshot original turns once to avoid drift on repeated grouping\n"
|
| 969 |
+
" const getOriginalTurns = () => {\n"
|
| 970 |
+
" if (!flow.dataset.origData) {\n"
|
| 971 |
+
" const data = [];\n"
|
| 972 |
+
" const cols0 = flow.querySelectorAll('.split-col');\n"
|
| 973 |
+
" cols0.forEach(col => {\n"
|
| 974 |
+
" const agent = col.getAttribute('data-agent') || '';\n"
|
| 975 |
+
" col.querySelectorAll(':scope > .chat-turn').forEach(el => {\n"
|
| 976 |
+
" const t = parseInt(el.getAttribute('data-time-step')||'0',10);\n"
|
| 977 |
+
" data.push({agent, time:t, html: el.outerHTML});\n"
|
| 978 |
+
" });\n"
|
| 979 |
+
" });\n"
|
| 980 |
+
" flow.dataset.origData = JSON.stringify(data);\n"
|
| 981 |
+
" }\n"
|
| 982 |
+
" return JSON.parse(flow.dataset.origData);\n"
|
| 983 |
+
" };\n"
|
| 984 |
+
" const original = getOriginalTurns();\n"
|
| 985 |
+
" const agents = Array.from(new Set(original.map(o => o.agent))).sort();\n"
|
| 986 |
+
" const groups = new Map();\n"
|
| 987 |
+
" original.forEach(o => {\n"
|
| 988 |
+
" const g = n && n > 0 ? Math.floor(o.time / n) : 0;\n"
|
| 989 |
+
" if (!groups.has(g)) groups.set(g, new Map());\n"
|
| 990 |
+
" const gm = groups.get(g);\n"
|
| 991 |
+
" if (!gm.has(o.agent)) gm.set(o.agent, []);\n"
|
| 992 |
+
" gm.get(o.agent).push(o);\n"
|
| 993 |
+
" });\n"
|
| 994 |
+
" flow.innerHTML = '';\n"
|
| 995 |
+
" const sorted = Array.from(groups.keys()).sort((a,b)=>a-b);\n"
|
| 996 |
+
" sorted.forEach(g => {\n"
|
| 997 |
+
" const div = document.createElement('div');\n"
|
| 998 |
+
" div.className = 'group-divider';\n"
|
| 999 |
+
" const label = document.createElement('span');\n"
|
| 1000 |
+
" label.className = 'group-label';\n"
|
| 1001 |
+
" label.textContent = `Round ${g+1}`;\n"
|
| 1002 |
+
" div.appendChild(label);\n"
|
| 1003 |
+
" flow.appendChild(div);\n"
|
| 1004 |
+
" const wrapper = document.createElement('div');\n"
|
| 1005 |
+
" wrapper.className = 'split-wrapper';\n"
|
| 1006 |
+
" agents.forEach(agent => {\n"
|
| 1007 |
+
" const colDiv = document.createElement('div');\n"
|
| 1008 |
+
" colDiv.className = 'split-col';\n"
|
| 1009 |
+
" colDiv.setAttribute('data-agent', agent);\n"
|
| 1010 |
+
" (groups.get(g).get(agent) || []).forEach(o => { colDiv.insertAdjacentHTML('beforeend', o.html); });\n"
|
| 1011 |
+
" wrapper.appendChild(colDiv);\n"
|
| 1012 |
+
" });\n"
|
| 1013 |
+
" if (wrapper.children.length === 2) { const res = document.createElement('div'); res.className='split-resizer'; wrapper.insertBefore(res, wrapper.children[1]); }\n"
|
| 1014 |
+
" flow.appendChild(wrapper);\n"
|
| 1015 |
+
" });\n"
|
| 1016 |
+
" flow.classList.toggle('hide-ts-badges', n === 1);\n"
|
| 1017 |
+
" flow.classList.toggle('strong-hide', strongHideOn);\n"
|
| 1018 |
+
" document.body.classList.add('split-mode');\n"
|
| 1019 |
+
" } else {\n"
|
| 1020 |
+
" groupContainer(flow, n);\n"
|
| 1021 |
+
" }\n"
|
| 1022 |
+
" }\n"
|
| 1023 |
+
" applyRangeFilter();\n"
|
| 1024 |
+
" initSplitResizers();\n"
|
| 1025 |
+
" }\n"
|
| 1026 |
+
" function initSplitResizers() {\n"
|
| 1027 |
+
" const wrappers = document.querySelectorAll('#flow-split .split-wrapper');\n"
|
| 1028 |
+
" wrappers.forEach(wrap => {\n"
|
| 1029 |
+
" const resizer = wrap.querySelector('.split-resizer');\n"
|
| 1030 |
+
" if (!resizer || resizer.dataset.bound) return; resizer.dataset.bound='1';\n"
|
| 1031 |
+
" const cols = wrap.querySelectorAll('.split-col'); if (cols.length !== 2) return; const c0=cols[0], c1=cols[1];\n"
|
| 1032 |
+
" c0.style.flex=c1.style.flex='1 1 0'; c0.style.width=c1.style.width='';\n"
|
| 1033 |
+
" requestAnimationFrame(()=>{ const w0=c0.scrollWidth,w1=c1.scrollWidth,total=w0+w1||1; let p0=w0/total,p1=w1/total; const minP=0.25,maxP=0.75; if(p0<minP){p0=minP;p1=1-p0;} else if(p0>maxP){p0=maxP;p1=1-p0;} c0.style.flex='0 0 '+(p0*100).toFixed(2)+'%'; c1.style.flex='0 0 '+(p1*100).toFixed(2)+'%'; });\n"
|
| 1034 |
+
" let dragging=false,startX=0,startP0=0;\n"
|
| 1035 |
+
" const onDown=e=>{ dragging=true; startX=e.clientX; wrap.classList.add('resizing'); resizer.classList.add('dragging'); const rect=wrap.getBoundingClientRect(); const w=rect.width; const c0Rect=c0.getBoundingClientRect(); startP0=c0Rect.width/w; document.body.style.cursor='col-resize'; e.preventDefault(); };\n"
|
| 1036 |
+
" const onMove=e=>{ if(!dragging)return; const rect=wrap.getBoundingClientRect(); const w=rect.width; let delta=(e.clientX-startX)/w; let newP0=startP0+delta; const minP=0.15,maxP=0.85; if(newP0<minP)newP0=minP; if(newP0>maxP)newP0=maxP; c0.style.flex='0 0 '+(newP0*100).toFixed(2)+'%'; c1.style.flex='0 0 '+((1-newP0)*100).toFixed(2)+'%'; };\n"
|
| 1037 |
+
" const onUp=()=>{ if(!dragging)return; dragging=false; wrap.classList.remove('resizing'); resizer.classList.remove('dragging'); document.body.style.cursor=''; };\n"
|
| 1038 |
+
" resizer.addEventListener('mousedown', onDown); window.addEventListener('mousemove', onMove); window.addEventListener('mouseup', onUp);\n"
|
| 1039 |
+
" resizer.addEventListener('dblclick', e=>{ if(e.shiftKey){ c0.style.flex=c1.style.flex='1 1 0'; requestAnimationFrame(()=>{ const w0=c0.scrollWidth,w1=c1.scrollWidth,total=w0+w1||1; let p0=w0/total,p1=w1/total; const minP=0.25,maxP=0.75; if(p0<minP){p0=minP;p1=1-p0;} else if(p0>maxP){p0=maxP;p1=1-p0;} c0.style.flex='0 0 '+(p0*100).toFixed(2)+'%'; c1.style.flex='0 0 '+(p1*100).toFixed(2)+'%'; }); } else { c0.style.flex='0 0 50%'; c1.style.flex='0 0 50%'; } });\n"
|
| 1040 |
+
" });\n"
|
| 1041 |
+
" }\n"
|
| 1042 |
+
" initSplitResizers();\n"
|
| 1043 |
+
" const input = document.getElementById('group-size');\n"
|
| 1044 |
+
" const btn = document.getElementById('apply-grouping');\n"
|
| 1045 |
+
" if (btn && input) {\n"
|
| 1046 |
+
" btn.addEventListener('click', () => { const n = parseInt(input.value || '0', 10); applyGrouping(n); });\n"
|
| 1047 |
+
" input.addEventListener('keydown', (e) => { if (e.key === 'Enter') { const n = parseInt(input.value || '0', 10); applyGrouping(n); } });\n"
|
| 1048 |
+
" }\n"
|
| 1049 |
+
" if (input) { input.value = '1'; applyGrouping(1); }\n"
|
| 1050 |
+
" const rangeStart = document.getElementById('range-start');\n"
|
| 1051 |
+
" const rangeEnd = document.getElementById('range-end');\n"
|
| 1052 |
+
" const rangeBtn = document.getElementById('apply-range');\n"
|
| 1053 |
+
" if (rangeBtn && rangeStart && rangeEnd) {\n"
|
| 1054 |
+
" const applyRange = () => {\n"
|
| 1055 |
+
" const sv = parseInt(rangeStart.value || '', 10);\n"
|
| 1056 |
+
" const ev = parseInt(rangeEnd.value || '', 10);\n"
|
| 1057 |
+
" currentRangeStart = Number.isFinite(sv) ? sv : null;\n"
|
| 1058 |
+
" currentRangeEnd = Number.isFinite(ev) ? ev : null;\n"
|
| 1059 |
+
" applyRangeFilter();\n"
|
| 1060 |
+
" };\n"
|
| 1061 |
+
" rangeBtn.addEventListener('click', applyRange);\n"
|
| 1062 |
+
" rangeStart.addEventListener('keydown', (e) => { if (e.key === 'Enter') applyRange(); });\n"
|
| 1063 |
+
" rangeEnd.addEventListener('keydown', (e) => { if (e.key === 'Enter') applyRange(); });\n"
|
| 1064 |
+
" }\n"
|
| 1065 |
+
" const strongHideBtn = document.getElementById('toggle-strong-hide');\n"
|
| 1066 |
+
" const strongHideStateEl = document.getElementById('strong-hide-state');\n"
|
| 1067 |
+
" if (strongHideBtn) {\n"
|
| 1068 |
+
" const setLabel = () => { if (strongHideStateEl) { strongHideStateEl.textContent = strongHideOn ? 'On' : 'Off'; } };\n"
|
| 1069 |
+
" strongHideBtn.addEventListener('click', () => { strongHideOn = !strongHideOn; for (const f of activeFlows()) { f.classList.toggle('strong-hide', strongHideOn); } setLabel(); });\n"
|
| 1070 |
+
" if (strongHideOn) { for (const f of activeFlows()) { f.classList.add('strong-hide'); } }\n"
|
| 1071 |
+
" setLabel();\n"
|
| 1072 |
+
" }\n"
|
| 1073 |
+
" const splitBtn = document.getElementById('toggle-split-view');\n"
|
| 1074 |
+
" const splitStateEl = document.getElementById('split-view-state');\n"
|
| 1075 |
+
" if (splitBtn && splitFlow && linearFlow) {\n"
|
| 1076 |
+
" const updateSplit = () => { if (splitStateEl) splitStateEl.textContent = splitViewOn ? 'On' : 'Off'; };\n"
|
| 1077 |
+
" splitBtn.addEventListener('click', () => { if (chatViewOn) return; splitViewOn = !splitViewOn; linearFlow.style.display = splitViewOn ? 'none' : ''; splitFlow.style.display = splitViewOn ? '' : 'none'; applyGrouping(parseInt(input.value||'1',10)); updateSplit(); });\n"
|
| 1078 |
+
" updateSplit();\n"
|
| 1079 |
+
" }\n"
|
| 1080 |
+
" const chatBtn = document.getElementById('toggle-chat-view');\n"
|
| 1081 |
+
" const chatStateEl = document.getElementById('chat-view-state');\n"
|
| 1082 |
+
" const hideUserBtn = document.getElementById('toggle-hide-user-messages');\n"
|
| 1083 |
+
" const hideUserStateEl = document.getElementById('hide-user-state');\n"
|
| 1084 |
+
" const widthControl = document.getElementById('chat-width-control');\n"
|
| 1085 |
+
" const widthSlider = document.getElementById('chat-width-slider');\n"
|
| 1086 |
+
" const widthValue = document.getElementById('chat-width-value');\n"
|
| 1087 |
+
" let hideUserMessages = false;\n"
|
| 1088 |
+
" if (chatBtn && chatFlow && linearFlow) {\n"
|
| 1089 |
+
" const updateChat = () => {\n"
|
| 1090 |
+
" if (chatStateEl) chatStateEl.textContent = chatViewOn ? 'On' : 'Off';\n"
|
| 1091 |
+
" if (hideUserBtn) hideUserBtn.style.display = chatViewOn ? '' : 'none';\n"
|
| 1092 |
+
" if (widthControl) widthControl.style.display = chatViewOn ? '' : 'none';\n"
|
| 1093 |
+
" };\n"
|
| 1094 |
+
" chatBtn.addEventListener('click', () => {\n"
|
| 1095 |
+
" chatViewOn = !chatViewOn;\n"
|
| 1096 |
+
" if (chatViewOn) {\n"
|
| 1097 |
+
" splitViewOn = false;\n"
|
| 1098 |
+
" linearFlow.style.display = 'none';\n"
|
| 1099 |
+
" if (splitFlow) splitFlow.style.display = 'none';\n"
|
| 1100 |
+
" chatFlow.style.display = '';\n"
|
| 1101 |
+
" if (splitStateEl) splitStateEl.textContent = 'Off';\n"
|
| 1102 |
+
" } else {\n"
|
| 1103 |
+
" chatFlow.style.display = 'none';\n"
|
| 1104 |
+
" linearFlow.style.display = '';\n"
|
| 1105 |
+
" }\n"
|
| 1106 |
+
" updateChat();\n"
|
| 1107 |
+
" });\n"
|
| 1108 |
+
" updateChat();\n"
|
| 1109 |
+
" }\n"
|
| 1110 |
+
" if (hideUserBtn && hideUserStateEl && chatFlow) {\n"
|
| 1111 |
+
" const updateHideUser = () => { hideUserStateEl.textContent = hideUserMessages ? 'On' : 'Off'; };\n"
|
| 1112 |
+
" hideUserBtn.addEventListener('click', () => {\n"
|
| 1113 |
+
" hideUserMessages = !hideUserMessages;\n"
|
| 1114 |
+
" chatFlow.classList.toggle('hide-user-messages', hideUserMessages);\n"
|
| 1115 |
+
" updateHideUser();\n"
|
| 1116 |
+
" });\n"
|
| 1117 |
+
" updateHideUser();\n"
|
| 1118 |
+
" }\n"
|
| 1119 |
+
" if (widthSlider && widthValue && chatFlow) {\n"
|
| 1120 |
+
" const savedWidth = localStorage.getItem('chat-view-width');\n"
|
| 1121 |
+
" if (savedWidth) {\n"
|
| 1122 |
+
" widthSlider.value = savedWidth;\n"
|
| 1123 |
+
" chatFlow.style.setProperty('--chat-width', savedWidth + 'px');\n"
|
| 1124 |
+
" widthValue.textContent = savedWidth + 'px';\n"
|
| 1125 |
+
" }\n"
|
| 1126 |
+
" widthSlider.addEventListener('input', (e) => {\n"
|
| 1127 |
+
" const width = e.target.value;\n"
|
| 1128 |
+
" chatFlow.style.setProperty('--chat-width', width + 'px');\n"
|
| 1129 |
+
" widthValue.textContent = width + 'px';\n"
|
| 1130 |
+
" localStorage.setItem('chat-view-width', width);\n"
|
| 1131 |
+
" });\n"
|
| 1132 |
+
" }\n"
|
| 1133 |
+
" const fontFamilySelect = document.getElementById('font-family-select');\n"
|
| 1134 |
+
" const fontSizeInput = document.getElementById('font-size-input');\n"
|
| 1135 |
+
" if (fontFamilySelect) {\n"
|
| 1136 |
+
" const savedFont = localStorage.getItem('render-font-family');\n"
|
| 1137 |
+
" if (savedFont) {\n"
|
| 1138 |
+
" fontFamilySelect.value = savedFont;\n"
|
| 1139 |
+
" document.body.style.setProperty('--font-family', savedFont);\n"
|
| 1140 |
+
" }\n"
|
| 1141 |
+
" fontFamilySelect.addEventListener('change', (e) => {\n"
|
| 1142 |
+
" const font = e.target.value;\n"
|
| 1143 |
+
" document.body.style.setProperty('--font-family', font);\n"
|
| 1144 |
+
" localStorage.setItem('render-font-family', font);\n"
|
| 1145 |
+
" });\n"
|
| 1146 |
+
" }\n"
|
| 1147 |
+
" if (fontSizeInput) {\n"
|
| 1148 |
+
" const savedSize = localStorage.getItem('render-font-size');\n"
|
| 1149 |
+
" if (savedSize) {\n"
|
| 1150 |
+
" fontSizeInput.value = savedSize;\n"
|
| 1151 |
+
" document.body.style.setProperty('--font-size', savedSize + 'px');\n"
|
| 1152 |
+
" }\n"
|
| 1153 |
+
" fontSizeInput.addEventListener('input', (e) => {\n"
|
| 1154 |
+
" const size = e.target.value;\n"
|
| 1155 |
+
" document.body.style.setProperty('--font-size', size + 'px');\n"
|
| 1156 |
+
" localStorage.setItem('render-font-size', size);\n"
|
| 1157 |
+
" });\n"
|
| 1158 |
+
" }\n"
|
| 1159 |
+
" const aliceEmojiInput = document.getElementById('alice-emoji-input');\n"
|
| 1160 |
+
" const aliceNameInput = document.getElementById('alice-name-input');\n"
|
| 1161 |
+
" const bobEmojiInput = document.getElementById('bob-emoji-input');\n"
|
| 1162 |
+
" const bobNameInput = document.getElementById('bob-name-input');\n"
|
| 1163 |
+
" const applyAgentNamesBtn = document.getElementById('apply-agent-names');\n"
|
| 1164 |
+
" function loadAgentNames() {\n"
|
| 1165 |
+
" if (aliceEmojiInput && aliceNameInput && bobEmojiInput && bobNameInput) {\n"
|
| 1166 |
+
" const savedAliceEmoji = localStorage.getItem('alice-emoji') || '🤖';\n"
|
| 1167 |
+
" const savedAliceName = localStorage.getItem('alice-name') || 'Alice';\n"
|
| 1168 |
+
" const savedBobEmoji = localStorage.getItem('bob-emoji') || '🤖';\n"
|
| 1169 |
+
" const savedBobName = localStorage.getItem('bob-name') || 'Bob';\n"
|
| 1170 |
+
" aliceEmojiInput.value = savedAliceEmoji;\n"
|
| 1171 |
+
" aliceNameInput.value = savedAliceName;\n"
|
| 1172 |
+
" bobEmojiInput.value = savedBobEmoji;\n"
|
| 1173 |
+
" bobNameInput.value = savedBobName;\n"
|
| 1174 |
+
" applyAgentNamesToDOM(savedAliceEmoji, savedAliceName, savedBobEmoji, savedBobName);\n"
|
| 1175 |
+
" }\n"
|
| 1176 |
+
" }\n"
|
| 1177 |
+
" function applyAgentNamesToDOM(aliceEmoji, aliceName, bobEmoji, bobName) {\n"
|
| 1178 |
+
" const agentMap = { 'alice': { name: aliceName, emoji: aliceEmoji }, 'bob': { name: bobName, emoji: bobEmoji } };\n"
|
| 1179 |
+
" document.querySelectorAll('[data-agent-id]').forEach(el => {\n"
|
| 1180 |
+
" const agentId = el.getAttribute('data-agent-id');\n"
|
| 1181 |
+
" if (!agentMap[agentId]) return;\n"
|
| 1182 |
+
" if (el.classList.contains('agent-name')) {\n"
|
| 1183 |
+
" el.textContent = agentMap[agentId].name;\n"
|
| 1184 |
+
" } else if (el.classList.contains('emoji-bw')) {\n"
|
| 1185 |
+
" const currentEmoji = el.textContent.trim();\n"
|
| 1186 |
+
" if (currentEmoji === '🤖' || currentEmoji === '👤') {\n"
|
| 1187 |
+
" el.textContent = agentMap[agentId].emoji;\n"
|
| 1188 |
+
" }\n"
|
| 1189 |
+
" }\n"
|
| 1190 |
+
" });\n"
|
| 1191 |
+
" const style = document.createElement('style');\n"
|
| 1192 |
+
" style.id = 'dynamic-agent-names-style';\n"
|
| 1193 |
+
" const existingStyle = document.getElementById('dynamic-agent-names-style');\n"
|
| 1194 |
+
" if (existingStyle) existingStyle.remove();\n"
|
| 1195 |
+
" style.textContent = `\n"
|
| 1196 |
+
" .agent-context-box.agent-alice .round-context-edit::before {\n"
|
| 1197 |
+
" content: '${aliceName} Prompt Summary:';\n"
|
| 1198 |
+
" }\n"
|
| 1199 |
+
" .agent-context-box.agent-bob .round-context-edit::before {\n"
|
| 1200 |
+
" content: '${bobName} Prompt Summary:';\n"
|
| 1201 |
+
" }\n"
|
| 1202 |
+
" `;\n"
|
| 1203 |
+
" document.head.appendChild(style);\n"
|
| 1204 |
+
" }\n"
|
| 1205 |
+
" if (applyAgentNamesBtn && aliceEmojiInput && aliceNameInput && bobEmojiInput && bobNameInput) {\n"
|
| 1206 |
+
" [aliceEmojiInput, aliceNameInput, bobEmojiInput, bobNameInput].forEach(input => {\n"
|
| 1207 |
+
" input.style.pointerEvents = 'auto';\n"
|
| 1208 |
+
" if (input.tagName === 'INPUT') {\n"
|
| 1209 |
+
" input.style.userSelect = 'text';\n"
|
| 1210 |
+
" input.style.webkitUserSelect = 'text';\n"
|
| 1211 |
+
" input.readOnly = false;\n"
|
| 1212 |
+
" }\n"
|
| 1213 |
+
" input.disabled = false;\n"
|
| 1214 |
+
" const stopAll = (e) => { e.stopPropagation(); e.stopImmediatePropagation(); };\n"
|
| 1215 |
+
" input.addEventListener('mousedown', stopAll, true);\n"
|
| 1216 |
+
" input.addEventListener('mouseup', stopAll, true);\n"
|
| 1217 |
+
" input.addEventListener('click', stopAll, true);\n"
|
| 1218 |
+
" input.addEventListener('dblclick', stopAll, true);\n"
|
| 1219 |
+
" input.addEventListener('focus', stopAll, true);\n"
|
| 1220 |
+
" input.addEventListener('blur', stopAll, true);\n"
|
| 1221 |
+
" input.addEventListener('paste', stopAll, true);\n"
|
| 1222 |
+
" input.addEventListener('cut', stopAll, true);\n"
|
| 1223 |
+
" input.addEventListener('copy', stopAll, true);\n"
|
| 1224 |
+
" input.addEventListener('select', stopAll, true);\n"
|
| 1225 |
+
" input.addEventListener('selectstart', stopAll, true);\n"
|
| 1226 |
+
" input.addEventListener('keydown', stopAll, true);\n"
|
| 1227 |
+
" input.addEventListener('keyup', stopAll, true);\n"
|
| 1228 |
+
" input.addEventListener('keypress', stopAll, true);\n"
|
| 1229 |
+
" input.addEventListener('input', stopAll, true);\n"
|
| 1230 |
+
" input.addEventListener('change', stopAll, true);\n"
|
| 1231 |
+
" input.addEventListener('contextmenu', stopAll, true);\n"
|
| 1232 |
+
" });\n"
|
| 1233 |
+
" const applyNames = () => {\n"
|
| 1234 |
+
" const aliceEmoji = aliceEmojiInput.value || '🤖';\n"
|
| 1235 |
+
" const aliceName = aliceNameInput.value.trim() || 'Alice';\n"
|
| 1236 |
+
" const bobEmoji = bobEmojiInput.value || '🤖';\n"
|
| 1237 |
+
" const bobName = bobNameInput.value.trim() || 'Bob';\n"
|
| 1238 |
+
" localStorage.setItem('alice-emoji', aliceEmoji);\n"
|
| 1239 |
+
" localStorage.setItem('alice-name', aliceName);\n"
|
| 1240 |
+
" localStorage.setItem('bob-emoji', bobEmoji);\n"
|
| 1241 |
+
" localStorage.setItem('bob-name', bobName);\n"
|
| 1242 |
+
" applyAgentNamesToDOM(aliceEmoji, aliceName, bobEmoji, bobName);\n"
|
| 1243 |
+
" };\n"
|
| 1244 |
+
" applyAgentNamesBtn.addEventListener('click', applyNames);\n"
|
| 1245 |
+
" [aliceNameInput, bobNameInput].forEach(input => {\n"
|
| 1246 |
+
" input.addEventListener('keydown', (e) => {\n"
|
| 1247 |
+
" if (e.key === 'Enter') {\n"
|
| 1248 |
+
" e.preventDefault();\n"
|
| 1249 |
+
" e.stopPropagation();\n"
|
| 1250 |
+
" e.stopImmediatePropagation();\n"
|
| 1251 |
+
" applyNames();\n"
|
| 1252 |
+
" }\n"
|
| 1253 |
+
" }, true);\n"
|
| 1254 |
+
" });\n"
|
| 1255 |
+
" [aliceEmojiInput, bobEmojiInput].forEach(select => {\n"
|
| 1256 |
+
" select.addEventListener('change', applyNames);\n"
|
| 1257 |
+
" });\n"
|
| 1258 |
+
" }\n"
|
| 1259 |
+
" loadAgentNames();\n"
|
| 1260 |
+
" function setupRoundCollapse() {\n"
|
| 1261 |
+
" document.addEventListener('click', function(e) {\n"
|
| 1262 |
+
" if (e.target.closest('input, textarea, select, button, .round-context-edit, .toolbar')) { return; }\n"
|
| 1263 |
+
" const divider = e.target.closest('.chat-group-divider, .group-divider');\n"
|
| 1264 |
+
" if (!divider) return;\n"
|
| 1265 |
+
" divider.classList.toggle('collapsed');\n"
|
| 1266 |
+
" const isCollapsed = divider.classList.contains('collapsed');\n"
|
| 1267 |
+
" let nextElement = divider.nextElementSibling;\n"
|
| 1268 |
+
" while (nextElement) {\n"
|
| 1269 |
+
" if (nextElement.classList.contains('chat-group-divider') || nextElement.classList.contains('group-divider')) {\n"
|
| 1270 |
+
" break;\n"
|
| 1271 |
+
" }\n"
|
| 1272 |
+
" if (isCollapsed) {\n"
|
| 1273 |
+
" if (!nextElement.dataset.originalDisplay) {\n"
|
| 1274 |
+
" nextElement.dataset.originalDisplay = nextElement.style.display || getComputedStyle(nextElement).display;\n"
|
| 1275 |
+
" }\n"
|
| 1276 |
+
" nextElement.style.display = 'none';\n"
|
| 1277 |
+
" } else {\n"
|
| 1278 |
+
" if (nextElement.dataset.originalDisplay) {\n"
|
| 1279 |
+
" const originalDisplay = nextElement.dataset.originalDisplay;\n"
|
| 1280 |
+
" nextElement.style.display = originalDisplay === 'none' ? '' : originalDisplay;\n"
|
| 1281 |
+
" if (nextElement.style.display === originalDisplay && originalDisplay !== 'none') {\n"
|
| 1282 |
+
" nextElement.style.display = '';\n"
|
| 1283 |
+
" }\n"
|
| 1284 |
+
" delete nextElement.dataset.originalDisplay;\n"
|
| 1285 |
+
" } else {\n"
|
| 1286 |
+
" nextElement.style.display = '';\n"
|
| 1287 |
+
" }\n"
|
| 1288 |
+
" }\n"
|
| 1289 |
+
" nextElement = nextElement.nextElementSibling;\n"
|
| 1290 |
+
" }\n"
|
| 1291 |
+
" e.stopPropagation();\n"
|
| 1292 |
+
" });\n"
|
| 1293 |
+
" }\n"
|
| 1294 |
+
" setupRoundCollapse();\n"
|
| 1295 |
+
" const strongHideBtnChat = document.getElementById('toggle-strong-hide');\n"
|
| 1296 |
+
" function applyStrongHideToChat() {\n"
|
| 1297 |
+
" if (!chatFlow) return;\n"
|
| 1298 |
+
" chatFlow.classList.toggle('strong-hide', strongHideOn);\n"
|
| 1299 |
+
" const contextEdits = chatFlow.querySelectorAll('.round-context-edit');\n"
|
| 1300 |
+
" contextEdits.forEach(edit => {\n"
|
| 1301 |
+
" const parent = edit.closest('.round-context, .agent-context-box, .split-agent-context');\n"
|
| 1302 |
+
" if (parent) {\n"
|
| 1303 |
+
" if (strongHideOn && edit.textContent.trim() === '') {\n"
|
| 1304 |
+
" parent.style.display = 'none';\n"
|
| 1305 |
+
" } else {\n"
|
| 1306 |
+
" parent.style.display = '';\n"
|
| 1307 |
+
" }\n"
|
| 1308 |
+
" }\n"
|
| 1309 |
+
" });\n"
|
| 1310 |
+
" const splitContexts = chatFlow.querySelectorAll('.split-agent-context');\n"
|
| 1311 |
+
" splitContexts.forEach(split => {\n"
|
| 1312 |
+
" if (strongHideOn) {\n"
|
| 1313 |
+
" const boxes = split.querySelectorAll('.agent-context-box');\n"
|
| 1314 |
+
" const allEmpty = Array.from(boxes).every(box => {\n"
|
| 1315 |
+
" const edit = box.querySelector('.round-context-edit');\n"
|
| 1316 |
+
" return edit && edit.textContent.trim() === '';\n"
|
| 1317 |
+
" });\n"
|
| 1318 |
+
" if (allEmpty) split.style.display = 'none';\n"
|
| 1319 |
+
" }\n"
|
| 1320 |
+
" });\n"
|
| 1321 |
+
" }\n"
|
| 1322 |
+
" if (strongHideBtnChat && chatFlow) {\n"
|
| 1323 |
+
" strongHideBtnChat.addEventListener('click', () => {\n"
|
| 1324 |
+
" setTimeout(() => applyStrongHideToChat(), 0);\n"
|
| 1325 |
+
" });\n"
|
| 1326 |
+
" }\n"
|
| 1327 |
+
" document.addEventListener('click', function(e) {\n"
|
| 1328 |
+
" if (e.target.closest('input, textarea, select, .round-context-edit, .toolbar')) { return; }\n"
|
| 1329 |
+
" const chatReasoning = e.target.closest('.chat-reasoning');\n"
|
| 1330 |
+
" if (chatReasoning) {\n"
|
| 1331 |
+
" chatReasoning.classList.toggle('collapsed');\n"
|
| 1332 |
+
" }\n"
|
| 1333 |
+
" });\n"
|
| 1334 |
+
" function applyColorToSelection(color, element) {\n"
|
| 1335 |
+
" const selection = window.getSelection();\n"
|
| 1336 |
+
" if (!selection.rangeCount) return false;\n"
|
| 1337 |
+
" const range = selection.getRangeAt(0);\n"
|
| 1338 |
+
" if (!element.contains(range.commonAncestorContainer)) return false;\n"
|
| 1339 |
+
" const selectedText = range.toString();\n"
|
| 1340 |
+
" if (!selectedText) return false;\n"
|
| 1341 |
+
" if (color === 'default') {\n"
|
| 1342 |
+
" // Remove styling - just extract the text content\n"
|
| 1343 |
+
" const textNode = document.createTextNode(selectedText);\n"
|
| 1344 |
+
" range.deleteContents();\n"
|
| 1345 |
+
" range.insertNode(textNode);\n"
|
| 1346 |
+
" } else {\n"
|
| 1347 |
+
" const span = document.createElement('span');\n"
|
| 1348 |
+
" span.style.color = color;\n"
|
| 1349 |
+
" span.style.fontWeight = '600';\n"
|
| 1350 |
+
" try {\n"
|
| 1351 |
+
" range.surroundContents(span);\n"
|
| 1352 |
+
" } catch (e) {\n"
|
| 1353 |
+
" const contents = range.extractContents();\n"
|
| 1354 |
+
" span.appendChild(contents);\n"
|
| 1355 |
+
" range.insertNode(span);\n"
|
| 1356 |
+
" }\n"
|
| 1357 |
+
" }\n"
|
| 1358 |
+
" return true;\n"
|
| 1359 |
+
" }\n"
|
| 1360 |
+
" let lastFocusedContextEdit = null;\n"
|
| 1361 |
+
" document.addEventListener('focusin', function(e) {\n"
|
| 1362 |
+
" if (e.target.classList.contains('round-context-edit')) {\n"
|
| 1363 |
+
" lastFocusedContextEdit = e.target;\n"
|
| 1364 |
+
" }\n"
|
| 1365 |
+
" });\n"
|
| 1366 |
+
" document.addEventListener('mousedown', function(e) {\n"
|
| 1367 |
+
" if (e.target.classList.contains('context-color-btn')) {\n"
|
| 1368 |
+
" e.preventDefault();\n"
|
| 1369 |
+
" }\n"
|
| 1370 |
+
" });\n"
|
| 1371 |
+
" document.addEventListener('click', function(e) {\n"
|
| 1372 |
+
" if (e.target.closest('input:not(.round-context-edit), textarea, select') && !e.target.classList.contains('context-color-btn')) { return; }\n"
|
| 1373 |
+
" if (e.target.classList.contains('context-color-btn')) {\n"
|
| 1374 |
+
" e.preventDefault();\n"
|
| 1375 |
+
" const color = e.target.dataset.color;\n"
|
| 1376 |
+
" const controls = e.target.closest('.round-context-controls');\n"
|
| 1377 |
+
" const contextEdit = controls ? controls.previousElementSibling : null;\n"
|
| 1378 |
+
" if (contextEdit && contextEdit.classList.contains('round-context-edit')) {\n"
|
| 1379 |
+
" contextEdit.focus();\n"
|
| 1380 |
+
" const selection = window.getSelection();\n"
|
| 1381 |
+
" if (selection.rangeCount > 0 && selection.toString().length > 0 && contextEdit.contains(selection.anchorNode)) {\n"
|
| 1382 |
+
" if (applyColorToSelection(color, contextEdit)) {\n"
|
| 1383 |
+
" const key = contextEdit.dataset.contextKey;\n"
|
| 1384 |
+
" localStorage.setItem(key, contextEdit.innerHTML);\n"
|
| 1385 |
+
" }\n"
|
| 1386 |
+
" } else {\n"
|
| 1387 |
+
" try {\n"
|
| 1388 |
+
" if (color !== 'default') {\n"
|
| 1389 |
+
" document.execCommand('styleWithCSS', false, true);\n"
|
| 1390 |
+
" document.execCommand('foreColor', false, color);\n"
|
| 1391 |
+
" }\n"
|
| 1392 |
+
" const key = contextEdit.dataset.contextKey;\n"
|
| 1393 |
+
" setTimeout(() => localStorage.setItem(key, contextEdit.innerHTML), 10);\n"
|
| 1394 |
+
" } catch (e) {\n"
|
| 1395 |
+
" console.log('Color command failed:', e);\n"
|
| 1396 |
+
" }\n"
|
| 1397 |
+
" }\n"
|
| 1398 |
+
" }\n"
|
| 1399 |
+
" }\n"
|
| 1400 |
+
" });\n"
|
| 1401 |
+
" const contextEdits = document.querySelectorAll('.round-context-edit');\n"
|
| 1402 |
+
" contextEdits.forEach(edit => {\n"
|
| 1403 |
+
" edit.addEventListener('input', function() {\n"
|
| 1404 |
+
" const key = this.dataset.contextKey;\n"
|
| 1405 |
+
" localStorage.setItem(key, this.innerHTML);\n"
|
| 1406 |
+
" });\n"
|
| 1407 |
+
" const key = edit.dataset.contextKey;\n"
|
| 1408 |
+
" const saved = localStorage.getItem(key);\n"
|
| 1409 |
+
" if (saved) {\n"
|
| 1410 |
+
" edit.innerHTML = saved;\n"
|
| 1411 |
+
" }\n"
|
| 1412 |
+
" });\n"
|
| 1413 |
+
" document.addEventListener('click', function(e) {\n"
|
| 1414 |
+
" if (e.target.closest('input, textarea, select, .round-context-edit') && !e.target.classList.contains('merge-btn') && !e.target.classList.contains('unmerge-btn')) { return; }\n"
|
| 1415 |
+
" if (e.target.classList.contains('merge-btn')) {\n"
|
| 1416 |
+
" e.preventDefault();\n"
|
| 1417 |
+
" e.stopPropagation();\n"
|
| 1418 |
+
" const msgId = e.target.dataset.msgId;\n"
|
| 1419 |
+
" const currentMsg = e.target.closest('.chat-message');\n"
|
| 1420 |
+
" if (!currentMsg) return;\n"
|
| 1421 |
+
" if (currentMsg.classList.contains('role-user')) {\n"
|
| 1422 |
+
" alert('Cannot merge user messages');\n"
|
| 1423 |
+
" return;\n"
|
| 1424 |
+
" }\n"
|
| 1425 |
+
" let nextMsg = currentMsg.nextElementSibling;\n"
|
| 1426 |
+
" while (nextMsg && !nextMsg.classList.contains('chat-message')) {\n"
|
| 1427 |
+
" nextMsg = nextMsg.nextElementSibling;\n"
|
| 1428 |
+
" }\n"
|
| 1429 |
+
" while (nextMsg && nextMsg.classList.contains('role-user')) {\n"
|
| 1430 |
+
" nextMsg = nextMsg.nextElementSibling;\n"
|
| 1431 |
+
" while (nextMsg && !nextMsg.classList.contains('chat-message')) {\n"
|
| 1432 |
+
" nextMsg = nextMsg.nextElementSibling;\n"
|
| 1433 |
+
" }\n"
|
| 1434 |
+
" }\n"
|
| 1435 |
+
" if (!nextMsg || nextMsg.classList.contains('chat-message') === false) {\n"
|
| 1436 |
+
" alert('No next assistant message to merge with');\n"
|
| 1437 |
+
" return;\n"
|
| 1438 |
+
" }\n"
|
| 1439 |
+
" if (nextMsg.classList.contains('role-user')) {\n"
|
| 1440 |
+
" alert('Cannot merge with user messages');\n"
|
| 1441 |
+
" return;\n"
|
| 1442 |
+
" }\n"
|
| 1443 |
+
" const parent = currentMsg.parentElement;\n"
|
| 1444 |
+
" if (parent.classList.contains('simultaneous-messages')) {\n"
|
| 1445 |
+
" const wrapper = parent;\n"
|
| 1446 |
+
" currentMsg.style.display = '';\n"
|
| 1447 |
+
" currentMsg.classList.remove('merged');\n"
|
| 1448 |
+
" const refNode = wrapper.nextElementSibling;\n"
|
| 1449 |
+
" parent.parentElement.insertBefore(currentMsg, refNode);\n"
|
| 1450 |
+
" if (nextMsg.parentElement === wrapper) {\n"
|
| 1451 |
+
" parent.parentElement.insertBefore(nextMsg, refNode);\n"
|
| 1452 |
+
" }\n"
|
| 1453 |
+
" if (wrapper.children.length === 0) {\n"
|
| 1454 |
+
" wrapper.remove();\n"
|
| 1455 |
+
" }\n"
|
| 1456 |
+
" } else {\n"
|
| 1457 |
+
" const wrapper = document.createElement('div');\n"
|
| 1458 |
+
" wrapper.className = 'simultaneous-messages';\n"
|
| 1459 |
+
" const unmergeBtn = document.createElement('button');\n"
|
| 1460 |
+
" unmergeBtn.className = 'unmerge-btn';\n"
|
| 1461 |
+
" unmergeBtn.innerHTML = '✕';\n"
|
| 1462 |
+
" unmergeBtn.title = 'Click to unmerge messages';\n"
|
| 1463 |
+
" wrapper.appendChild(unmergeBtn);\n"
|
| 1464 |
+
" wrapper.dataset.firstMsgId = currentMsg.dataset.msgId;\n"
|
| 1465 |
+
" wrapper.dataset.secondMsgId = nextMsg.dataset.msgId;\n"
|
| 1466 |
+
" parent.insertBefore(wrapper, currentMsg);\n"
|
| 1467 |
+
" wrapper.appendChild(currentMsg);\n"
|
| 1468 |
+
" wrapper.appendChild(nextMsg);\n"
|
| 1469 |
+
" currentMsg.classList.add('merged');\n"
|
| 1470 |
+
" nextMsg.classList.add('merged');\n"
|
| 1471 |
+
" }\n"
|
| 1472 |
+
" }\n"
|
| 1473 |
+
" if (e.target.classList.contains('unmerge-btn')) {\n"
|
| 1474 |
+
" const wrapper = e.target.closest('.simultaneous-messages');\n"
|
| 1475 |
+
" if (!wrapper) return;\n"
|
| 1476 |
+
" const parent = wrapper.parentElement;\n"
|
| 1477 |
+
" const firstMsgId = wrapper.dataset.firstMsgId;\n"
|
| 1478 |
+
" const secondMsgId = wrapper.dataset.secondMsgId;\n"
|
| 1479 |
+
" const messages = Array.from(wrapper.querySelectorAll('.chat-message'));\n"
|
| 1480 |
+
" const refNode = wrapper.nextElementSibling;\n"
|
| 1481 |
+
" const firstMsg = messages.find(m => m.dataset.msgId === firstMsgId);\n"
|
| 1482 |
+
" const secondMsg = messages.find(m => m.dataset.msgId === secondMsgId);\n"
|
| 1483 |
+
" if (firstMsg) {\n"
|
| 1484 |
+
" firstMsg.classList.remove('merged');\n"
|
| 1485 |
+
" firstMsg.style.display = '';\n"
|
| 1486 |
+
" parent.insertBefore(firstMsg, refNode);\n"
|
| 1487 |
+
" }\n"
|
| 1488 |
+
" if (secondMsg) {\n"
|
| 1489 |
+
" secondMsg.classList.remove('merged');\n"
|
| 1490 |
+
" secondMsg.style.display = '';\n"
|
| 1491 |
+
" parent.insertBefore(secondMsg, refNode);\n"
|
| 1492 |
+
" }\n"
|
| 1493 |
+
" wrapper.remove();\n"
|
| 1494 |
+
" }\n"
|
| 1495 |
+
" });\n"
|
| 1496 |
+
"});\n"
|
| 1497 |
+
"</script>",
|
| 1498 |
+
"</head>",
|
| 1499 |
+
"<body>",
|
| 1500 |
+
'<div class="toolbar-wrap">',
|
| 1501 |
+
'<div class="toolbar-hotzone"></div>',
|
| 1502 |
+
'<div class="toolbar">',
|
| 1503 |
+
'<label for="group-size">Group every</label>',
|
| 1504 |
+
'<input id="group-size" type="number" min="0" step="1" value="1" />',
|
| 1505 |
+
"<span>timesteps</span>",
|
| 1506 |
+
'<button id="apply-grouping">Apply</button>',
|
| 1507 |
+
'<span style="margin-left:8px"></span>',
|
| 1508 |
+
'<label for="range-start"><span class="emoji-bw">🔎</span> Range</label>',
|
| 1509 |
+
'<input id="range-start" type="number" step="1" />',
|
| 1510 |
+
"<span>to</span>",
|
| 1511 |
+
'<input id="range-end" type="number" step="1" />',
|
| 1512 |
+
'<button id="apply-range"><span class="emoji-bw">▶︎</span> Apply</button>',
|
| 1513 |
+
'<button id="toggle-strong-hide"><span class="emoji-bw">🗜️</span> Strong Hide: <span id="strong-hide-state">Off</span></button>',
|
| 1514 |
+
(
|
| 1515 |
+
'<button id="toggle-split-view"><span class="emoji-bw">🪟</span> Split View: <span id="split-view-state">Off</span></button>'
|
| 1516 |
+
if enable_split_view
|
| 1517 |
+
else ""
|
| 1518 |
+
),
|
| 1519 |
+
'<button id="toggle-chat-view"><span class="emoji-bw">💬</span> Chat View: <span id="chat-view-state">On</span></button>',
|
| 1520 |
+
'<button id="toggle-hide-user-messages"><span class="emoji-bw">👁️</span> Hide Prompts: <span id="hide-user-state">Off</span></button>',
|
| 1521 |
+
'<span id="chat-width-control" style="margin-left:8px;">',
|
| 1522 |
+
'<label for="chat-width-slider"><span class="emoji-bw">↔️</span> Width:</label>',
|
| 1523 |
+
'<input id="chat-width-slider" type="range" min="600" max="1600" step="50" value="900" style="width:120px; vertical-align:middle;" />',
|
| 1524 |
+
'<span id="chat-width-value" style="margin-left:4px;">900px</span>',
|
| 1525 |
+
'</span>',
|
| 1526 |
+
'<span style="margin-left:12px;">',
|
| 1527 |
+
'<label for="font-family-select"><span class="emoji-bw">🔤</span> Font:</label>',
|
| 1528 |
+
'<select id="font-family-select" style="padding:2px 6px; border:1px solid var(--accent-muted); border-radius:var(--corner-radius); background:var(--bg);">',
|
| 1529 |
+
'<option value="\'Segoe UI\', Tahoma, Geneva, Verdana, sans-serif">Segoe UI</option>',
|
| 1530 |
+
'<option value="Arial, sans-serif">Arial</option>',
|
| 1531 |
+
'<option value="\'Helvetica Neue\', Helvetica, sans-serif">Helvetica</option>',
|
| 1532 |
+
'<option value="\'Times New Roman\', Times, serif">Times New Roman</option>',
|
| 1533 |
+
'<option value="Georgia, serif">Georgia</option>',
|
| 1534 |
+
'<option value="\'Courier New\', Courier, monospace">Courier New</option>',
|
| 1535 |
+
'<option value="\'Comic Sans MS\', cursive">Comic Sans</option>',
|
| 1536 |
+
'<option value="\'Trebuchet MS\', sans-serif">Trebuchet MS</option>',
|
| 1537 |
+
'<option value="Verdana, sans-serif">Verdana</option>',
|
| 1538 |
+
'<option value="\'Palatino Linotype\', \'Book Antiqua\', Palatino, serif">Palatino</option>',
|
| 1539 |
+
'<option value="\'Lucida Console\', Monaco, monospace">Lucida Console</option>',
|
| 1540 |
+
'</select>',
|
| 1541 |
+
'</span>',
|
| 1542 |
+
'<span style="margin-left:8px;">',
|
| 1543 |
+
'<label for="font-size-input"><span class="emoji-bw">📏</span> Size:</label>',
|
| 1544 |
+
'<input id="font-size-input" type="number" min="8" max="24" step="1" value="14" style="width:50px;" />',
|
| 1545 |
+
'<span>px</span>',
|
| 1546 |
+
'</span>',
|
| 1547 |
+
'<span style="margin-left:12px; display:flex; align-items:center; gap:8px;">',
|
| 1548 |
+
'<label style="font-weight:600;">Agent Names:</label>',
|
| 1549 |
+
'<select id="alice-emoji-input" style="width:65px; padding:2px 6px; border:1px solid var(--accent-muted); border-radius:var(--corner-radius); background:var(--bg);">',
|
| 1550 |
+
'<option value="🤖">🤖 Robot</option>',
|
| 1551 |
+
'<option value="👤">👤 Human</option>',
|
| 1552 |
+
'</select>',
|
| 1553 |
+
'<input id="alice-name-input" type="text" placeholder="Alice" style="width:80px; padding:2px 6px; border:1px solid var(--accent-muted); border-radius:var(--corner-radius); background:var(--bg);" />',
|
| 1554 |
+
'<span style="margin:0 4px;">|</span>',
|
| 1555 |
+
'<select id="bob-emoji-input" style="width:65px; padding:2px 6px; border:1px solid var(--accent-muted); border-radius:var(--corner-radius); background:var(--bg);">',
|
| 1556 |
+
'<option value="🤖">🤖 Robot</option>',
|
| 1557 |
+
'<option value="👤">👤 Human</option>',
|
| 1558 |
+
'</select>',
|
| 1559 |
+
'<input id="bob-name-input" type="text" placeholder="Bob" style="width:80px; padding:2px 6px; border:1px solid var(--accent-muted); border-radius:var(--corner-radius); background:var(--bg);" />',
|
| 1560 |
+
'<button id="apply-agent-names" style="padding:4px 8px; border:1px solid var(--accent-muted); background:var(--panel-bg); border-radius:var(--corner-radius); cursor:pointer;">Apply</button>',
|
| 1561 |
+
'</span>',
|
| 1562 |
+
"</div>",
|
| 1563 |
+
"</div>",
|
| 1564 |
+
'<div id="flow-linear" class="messages-flow" style="display:none">',
|
| 1565 |
+
]
|
| 1566 |
+
|
| 1567 |
+
last_time_step = None
|
| 1568 |
+
for original_index, turn in indexed_turns:
|
| 1569 |
+
# Build classes
|
| 1570 |
+
agent_class = f"agent-{re.sub('[^a-z0-9_-]', '-', turn.agent_id.lower())}"
|
| 1571 |
+
role_class = f"role-{turn.role}"
|
| 1572 |
+
collapsed_class = " collapsed" if turn.role == "user" else ""
|
| 1573 |
+
|
| 1574 |
+
# Badge content
|
| 1575 |
+
agent_id_clean = html.escape(turn.agent_id).lower()
|
| 1576 |
+
if turn.role == "assistant":
|
| 1577 |
+
name = html.escape(turn.agent_id)
|
| 1578 |
+
emoji = '<span class="emoji-bw" data-agent-id="' + agent_id_clean + '"> 🤖</span>'
|
| 1579 |
+
raw_val = turn.reward
|
| 1580 |
+
if isinstance(raw_val, (int, float)):
|
| 1581 |
+
reward_val = f"{raw_val:.4f}".rstrip("0").rstrip(".")
|
| 1582 |
+
if len(reward_val) > 8:
|
| 1583 |
+
reward_val = reward_val[:8] + "…"
|
| 1584 |
+
else:
|
| 1585 |
+
reward_val = str(raw_val)
|
| 1586 |
+
# Format: "🤖 Alice • Reward: 5.5556 • 💬 :"
|
| 1587 |
+
badge_inner = (
|
| 1588 |
+
f'{emoji} <span class="agent-name" data-agent-id="{agent_id_clean}">{name}</span>'
|
| 1589 |
+
f' <span class="sep"> • </span><span class="reward">Reward ⚑ = {reward_val}</span>'
|
| 1590 |
+
)
|
| 1591 |
+
else:
|
| 1592 |
+
# For user messages, show "Prompt of {Agent ID}" in the badge
|
| 1593 |
+
name = html.escape(turn.agent_id)
|
| 1594 |
+
# Format (no reward): "Prompt of Alice • "
|
| 1595 |
+
badge_inner = f'Prompt of <span class="agent-name" data-agent-id="{agent_id_clean}">{name}</span> <span class="sep"> • </span>:'
|
| 1596 |
+
|
| 1597 |
+
badge = f'<span class="agent-badge">{badge_inner}</span>'
|
| 1598 |
+
|
| 1599 |
+
# Inline timestep distinction badge at step boundaries (render before first message)
|
| 1600 |
+
ts_badge_html = ""
|
| 1601 |
+
if last_time_step is None or turn.time_step != last_time_step:
|
| 1602 |
+
ts_badge_html = f'<span class="ts-badge">⏱ {turn.time_step}</span>'
|
| 1603 |
+
last_time_step = turn.time_step
|
| 1604 |
+
|
| 1605 |
+
escaped_content = html.escape(turn.content)
|
| 1606 |
+
reasoning_html = ""
|
| 1607 |
+
if turn.reasoning_content:
|
| 1608 |
+
# Normalize reasoning to avoid leading/newline whitespace that creates visual gaps
|
| 1609 |
+
_raw_reasoning = turn.reasoning_content.replace("\r\n", "\n")
|
| 1610 |
+
_raw_reasoning = _re.sub(
|
| 1611 |
+
r"^\s*\n+", "", _raw_reasoning
|
| 1612 |
+
) # drop leading blank lines
|
| 1613 |
+
_raw_reasoning = _re.sub(
|
| 1614 |
+
r"\*\*(\s*\n\s*)", r"** ", _raw_reasoning
|
| 1615 |
+
) # newline right after **
|
| 1616 |
+
_raw_reasoning = _re.sub(
|
| 1617 |
+
r"(\s*\n\s*)\*\*", r" **", _raw_reasoning
|
| 1618 |
+
) # newline right before **
|
| 1619 |
+
escaped_reasoning = html.escape(_raw_reasoning)
|
| 1620 |
+
reasoning_html = f'<span class="reasoning-inline"><span class="reasoning-icon">💭</span><span class="reasoning-text">{escaped_reasoning}</span></span>'
|
| 1621 |
+
collapsed_text = re.sub(r"\s+", " ", escaped_content).strip()
|
| 1622 |
+
|
| 1623 |
+
html_parts.append(
|
| 1624 |
+
f'<div class="chat-turn {agent_class} {role_class}{collapsed_class}" data-time-step="{turn.time_step}">'
|
| 1625 |
+
f'<div class="turn-content {agent_class} {role_class}">{ts_badge_html}{badge}'
|
| 1626 |
+
f'<span class="message-box">{reasoning_html}<span class="main-content">💬 {collapsed_text}</span></span>'
|
| 1627 |
+
f'<span class="message-placeholder">(...)</span>'
|
| 1628 |
+
f"</div>"
|
| 1629 |
+
f"</div>"
|
| 1630 |
+
)
|
| 1631 |
+
|
| 1632 |
+
html_parts.append("</div>") # close linear flow
|
| 1633 |
+
if enable_split_view:
|
| 1634 |
+
import html as _html_mod
|
| 1635 |
+
|
| 1636 |
+
html_parts.append(
|
| 1637 |
+
'<div id="flow-split" class="messages-flow" style="display:none">'
|
| 1638 |
+
)
|
| 1639 |
+
html_parts.append('<div class="split-wrapper">')
|
| 1640 |
+
# Per-agent columns
|
| 1641 |
+
per_agent_turns = {
|
| 1642 |
+
aid: [t for t in chat_turns if t.agent_id == aid]
|
| 1643 |
+
for aid in assistant_agents
|
| 1644 |
+
}
|
| 1645 |
+
for idx, aid in enumerate(assistant_agents):
|
| 1646 |
+
turns_agent = per_agent_turns[aid]
|
| 1647 |
+
html_parts.append(
|
| 1648 |
+
f'<div class="split-col" data-agent="{_html_mod.escape(aid)}">'
|
| 1649 |
+
)
|
| 1650 |
+
last_ts_agent = None
|
| 1651 |
+
for turn in turns_agent:
|
| 1652 |
+
agent_class = (
|
| 1653 |
+
f"agent-{re.sub('[^a-z0-9_-]', '-', turn.agent_id.lower())}"
|
| 1654 |
+
)
|
| 1655 |
+
role_class = f"role-{turn.role}"
|
| 1656 |
+
collapsed_class = " collapsed" if turn.role == "user" else ""
|
| 1657 |
+
ts_badge_html = ""
|
| 1658 |
+
if last_ts_agent is None or turn.time_step != last_ts_agent:
|
| 1659 |
+
ts_badge_html = f'<span class="ts-badge">⏱ {turn.time_step}</span>'
|
| 1660 |
+
last_ts_agent = turn.time_step
|
| 1661 |
+
esc_content = _html_mod.escape(turn.content)
|
| 1662 |
+
reasoning_html = ""
|
| 1663 |
+
if turn.reasoning_content:
|
| 1664 |
+
_raw_reasoning = turn.reasoning_content.replace("\r\n", "\n")
|
| 1665 |
+
_raw_reasoning = _re.sub(r"^\s*\n+", "", _raw_reasoning)
|
| 1666 |
+
_raw_reasoning = _re.sub(r"\*\*(\s*\n\s*)", r"** ", _raw_reasoning)
|
| 1667 |
+
_raw_reasoning = _re.sub(r"(\s*\n\s*)\*\*", r" **", _raw_reasoning)
|
| 1668 |
+
esc_reasoning = _html_mod.escape(_raw_reasoning)
|
| 1669 |
+
reasoning_html = f'<span class="reasoning-inline"><span class="reasoning-icon">💭</span><span class="reasoning-text">{esc_reasoning}</span></span>'
|
| 1670 |
+
collapsed_text = re.sub(r"\s+", " ", esc_content).strip()
|
| 1671 |
+
agent_id_clean = _html_mod.escape(turn.agent_id).lower()
|
| 1672 |
+
if turn.role == "assistant":
|
| 1673 |
+
name = _html_mod.escape(turn.agent_id)
|
| 1674 |
+
emoji = '<span class="emoji-bw" data-agent-id="' + agent_id_clean + '"> 🤖</span>'
|
| 1675 |
+
raw_val = turn.reward
|
| 1676 |
+
if isinstance(raw_val, (int, float)):
|
| 1677 |
+
reward_val = f"{raw_val:.4f}".rstrip("0").rstrip(".")
|
| 1678 |
+
if len(reward_val) > 8:
|
| 1679 |
+
reward_val = reward_val[:8] + "…"
|
| 1680 |
+
else:
|
| 1681 |
+
reward_val = str(raw_val)
|
| 1682 |
+
badge_inner = (
|
| 1683 |
+
f'{emoji} <span class="agent-name" data-agent-id="{agent_id_clean}">{name}</span>'
|
| 1684 |
+
f' <span class="sep"> • </span><span class="reward">Reward ⚑ : {reward_val}</span>'
|
| 1685 |
+
)
|
| 1686 |
+
else:
|
| 1687 |
+
name = _html_mod.escape(turn.agent_id)
|
| 1688 |
+
badge_inner = f'Prompt of <span class="agent-name" data-agent-id="{agent_id_clean}">{name}</span> <span class="sep"> • </span>:'
|
| 1689 |
+
badge = f'<span class="agent-badge">{badge_inner}</span>'
|
| 1690 |
+
html_parts.append(
|
| 1691 |
+
f'<div class="chat-turn {agent_class} {role_class}{collapsed_class}" data-time-step="{turn.time_step}">'
|
| 1692 |
+
f'<div class="turn-content {agent_class} {role_class}">{ts_badge_html}{badge}'
|
| 1693 |
+
f'<span class="message-box">{reasoning_html}<span class="main-content">💬 {collapsed_text}</span></span>'
|
| 1694 |
+
f'<span class="message-placeholder">(...)</span>'
|
| 1695 |
+
f"</div></div>"
|
| 1696 |
+
)
|
| 1697 |
+
html_parts.append("</div>") # close split col
|
| 1698 |
+
html_parts.append("</div>") # split-wrapper
|
| 1699 |
+
html_parts.append("</div>") # flow-split
|
| 1700 |
+
|
| 1701 |
+
# Add Chat View
|
| 1702 |
+
import html as _html_mod
|
| 1703 |
+
html_parts.append('<div id="flow-chat" class="messages-flow">')
|
| 1704 |
+
|
| 1705 |
+
# Helper function to add context annotation areas
|
| 1706 |
+
def add_context_area(position: str, time_step: int):
|
| 1707 |
+
context_key = f"round-context-{position}-{time_step}"
|
| 1708 |
+
placeholder = f"Add context {position} round {time_step}..."
|
| 1709 |
+
color_buttons = ""
|
| 1710 |
+
# Add default/reset color button first
|
| 1711 |
+
color_buttons += (
|
| 1712 |
+
f'<div class="context-color-btn" data-color="default" '
|
| 1713 |
+
f'style="background: linear-gradient(135deg, #000 25%, transparent 25%, transparent 75%, #000 75%), '
|
| 1714 |
+
f'linear-gradient(135deg, #000 25%, transparent 25%, transparent 75%, #000 75%); '
|
| 1715 |
+
f'background-size: 4px 4px; background-position: 0 0, 2px 2px; '
|
| 1716 |
+
f'background-color: #fff;" title="Default color"></div>'
|
| 1717 |
+
)
|
| 1718 |
+
for color_name, color_value in [
|
| 1719 |
+
('red', '#d32f2f'),
|
| 1720 |
+
('orange', '#f57c00'),
|
| 1721 |
+
('yellow', '#f9a825'),
|
| 1722 |
+
('green', '#388e3c'),
|
| 1723 |
+
('blue', '#1976d2'),
|
| 1724 |
+
('purple', '#7b1fa2'),
|
| 1725 |
+
('gray', '#666666'),
|
| 1726 |
+
]:
|
| 1727 |
+
color_buttons += (
|
| 1728 |
+
f'<div class="context-color-btn" data-color="{color_value}" '
|
| 1729 |
+
f'style="background-color: {color_value};" title="{color_name}"></div>'
|
| 1730 |
+
)
|
| 1731 |
+
|
| 1732 |
+
html_parts.append(
|
| 1733 |
+
f'<div class="round-context">'
|
| 1734 |
+
f'<div class="round-context-edit" contenteditable="true" spellcheck="true" '
|
| 1735 |
+
f'data-context-key="{context_key}" '
|
| 1736 |
+
f'data-placeholder="{placeholder}"></div>'
|
| 1737 |
+
f'<div class="round-context-controls">{color_buttons}</div>'
|
| 1738 |
+
f'</div>'
|
| 1739 |
+
)
|
| 1740 |
+
|
| 1741 |
+
# Helper function to add split agent context boxes
|
| 1742 |
+
def add_split_agent_contexts(position: str, time_step: int):
|
| 1743 |
+
color_buttons = ""
|
| 1744 |
+
# Add default/reset color button first
|
| 1745 |
+
color_buttons += (
|
| 1746 |
+
f'<div class="context-color-btn" data-color="default" '
|
| 1747 |
+
f'style="background: linear-gradient(135deg, #000 25%, transparent 25%, transparent 75%, #000 75%), '
|
| 1748 |
+
f'linear-gradient(135deg, #000 25%, transparent 25%, transparent 75%, #000 75%); '
|
| 1749 |
+
f'background-size: 4px 4px; background-position: 0 0, 2px 2px; '
|
| 1750 |
+
f'background-color: #fff;" title="Default color"></div>'
|
| 1751 |
+
)
|
| 1752 |
+
for color_name, color_value in [
|
| 1753 |
+
('red', '#d32f2f'),
|
| 1754 |
+
('orange', '#f57c00'),
|
| 1755 |
+
('yellow', '#f9a825'),
|
| 1756 |
+
('green', '#388e3c'),
|
| 1757 |
+
('blue', '#1976d2'),
|
| 1758 |
+
('purple', '#7b1fa2'),
|
| 1759 |
+
('gray', '#666666'),
|
| 1760 |
+
]:
|
| 1761 |
+
color_buttons += (
|
| 1762 |
+
f'<div class="context-color-btn" data-color="{color_value}" '
|
| 1763 |
+
f'style="background-color: {color_value};" title="{color_name}"></div>'
|
| 1764 |
+
)
|
| 1765 |
+
|
| 1766 |
+
html_parts.append('<div class="split-agent-context">')
|
| 1767 |
+
|
| 1768 |
+
# Alice box
|
| 1769 |
+
alice_key = f"agent-context-alice-{position}-{time_step}"
|
| 1770 |
+
alice_placeholder = f"..."
|
| 1771 |
+
html_parts.append(
|
| 1772 |
+
f'<div class="agent-context-box agent-alice">'
|
| 1773 |
+
f'<div class="round-context-edit" contenteditable="true" spellcheck="true" '
|
| 1774 |
+
f'data-context-key="{alice_key}" '
|
| 1775 |
+
f'data-placeholder="{alice_placeholder}"></div>'
|
| 1776 |
+
f'<div class="round-context-controls">{color_buttons}</div>'
|
| 1777 |
+
f'</div>'
|
| 1778 |
+
)
|
| 1779 |
+
|
| 1780 |
+
# Bob box
|
| 1781 |
+
bob_key = f"agent-context-bob-{position}-{time_step}"
|
| 1782 |
+
bob_placeholder = f"..."
|
| 1783 |
+
html_parts.append(
|
| 1784 |
+
f'<div class="agent-context-box agent-bob">'
|
| 1785 |
+
f'<div class="round-context-edit" contenteditable="true" spellcheck="true" '
|
| 1786 |
+
f'data-context-key="{bob_key}" '
|
| 1787 |
+
f'data-placeholder="{bob_placeholder}"></div>'
|
| 1788 |
+
f'<div class="round-context-controls">{color_buttons}</div>'
|
| 1789 |
+
f'</div>'
|
| 1790 |
+
)
|
| 1791 |
+
|
| 1792 |
+
html_parts.append('</div>') # split-agent-context
|
| 1793 |
+
|
| 1794 |
+
last_time_step_chat = None
|
| 1795 |
+
for original_index, turn in indexed_turns:
|
| 1796 |
+
agent_class = f"agent-{re.sub('[^a-z0-9_-]', '-', turn.agent_id.lower())}"
|
| 1797 |
+
role_class = f"role-{turn.role}"
|
| 1798 |
+
|
| 1799 |
+
# Add time step divider and beginning context
|
| 1800 |
+
if last_time_step_chat is None or turn.time_step != last_time_step_chat:
|
| 1801 |
+
# Add end contexts for previous round (only regular context, not prompt summary)
|
| 1802 |
+
if last_time_step_chat is not None:
|
| 1803 |
+
add_context_area("end", last_time_step_chat)
|
| 1804 |
+
|
| 1805 |
+
html_parts.append(
|
| 1806 |
+
f'<div class="chat-group-divider">'
|
| 1807 |
+
f'<span class="chat-group-label">⏱ Round {turn.time_step + 1}</span>'
|
| 1808 |
+
f'</div>'
|
| 1809 |
+
)
|
| 1810 |
+
|
| 1811 |
+
# Add beginning contexts for new round (both context and prompt summary)
|
| 1812 |
+
add_context_area("beginning", turn.time_step)
|
| 1813 |
+
add_split_agent_contexts("beginning", turn.time_step)
|
| 1814 |
+
|
| 1815 |
+
last_time_step_chat = turn.time_step
|
| 1816 |
+
|
| 1817 |
+
# Build chat message with merge controls
|
| 1818 |
+
html_parts.append(f'<div class="chat-message {agent_class} {role_class}" data-msg-id="{original_index}">')
|
| 1819 |
+
|
| 1820 |
+
# Add merge control button
|
| 1821 |
+
html_parts.append(
|
| 1822 |
+
f'<button class="merge-btn" title="Merge with next message" data-msg-id="{original_index}">⇄</button>'
|
| 1823 |
+
)
|
| 1824 |
+
|
| 1825 |
+
html_parts.append('<div class="chat-message-content">')
|
| 1826 |
+
|
| 1827 |
+
# Header with agent name and reward (always show reward)
|
| 1828 |
+
agent_id_clean = _html_mod.escape(turn.agent_id).lower()
|
| 1829 |
+
if turn.role == "assistant":
|
| 1830 |
+
name = _html_mod.escape(turn.agent_id)
|
| 1831 |
+
raw_val = turn.reward
|
| 1832 |
+
if isinstance(raw_val, (int, float)):
|
| 1833 |
+
reward_val = f"{raw_val:.4f}".rstrip("0").rstrip(".")
|
| 1834 |
+
if len(reward_val) > 8:
|
| 1835 |
+
reward_val = reward_val[:8] + "…"
|
| 1836 |
+
else:
|
| 1837 |
+
reward_val = str(raw_val)
|
| 1838 |
+
header_html = (
|
| 1839 |
+
f'<div class="chat-header">'
|
| 1840 |
+
f'<span class="emoji-bw" data-agent-id="{agent_id_clean}">🤖</span> <span class="agent-name" data-agent-id="{agent_id_clean}">{name}</span>'
|
| 1841 |
+
f'<span class="chat-reward">⚑ {reward_val}</span>'
|
| 1842 |
+
f'</div>'
|
| 1843 |
+
)
|
| 1844 |
+
else:
|
| 1845 |
+
name = _html_mod.escape(turn.agent_id)
|
| 1846 |
+
header_html = f'<div class="chat-header">Prompt of <span class="agent-name" data-agent-id="{agent_id_clean}">{name}</span></div>'
|
| 1847 |
+
|
| 1848 |
+
html_parts.append(header_html)
|
| 1849 |
+
|
| 1850 |
+
# Reasoning content if present
|
| 1851 |
+
if turn.reasoning_content:
|
| 1852 |
+
_raw_reasoning = turn.reasoning_content.replace("\r\n", "\n")
|
| 1853 |
+
_raw_reasoning = _re.sub(r"^\s*\n+", "", _raw_reasoning)
|
| 1854 |
+
esc_reasoning = _html_mod.escape(_raw_reasoning)
|
| 1855 |
+
html_parts.append(
|
| 1856 |
+
f'<div class="chat-reasoning collapsed">'
|
| 1857 |
+
f'<span class="reasoning-icon">💭</span> '
|
| 1858 |
+
f'<span class="reasoning-text">{esc_reasoning}</span>'
|
| 1859 |
+
f'</div>'
|
| 1860 |
+
)
|
| 1861 |
+
|
| 1862 |
+
# Message bubble
|
| 1863 |
+
esc_content = _html_mod.escape(turn.content)
|
| 1864 |
+
html_parts.append(f'<div class="chat-bubble">{esc_content}</div>')
|
| 1865 |
+
|
| 1866 |
+
html_parts.append('</div>') # chat-message-content
|
| 1867 |
+
html_parts.append('</div>') # chat-message
|
| 1868 |
+
|
| 1869 |
+
# Add end contexts for the last round (only regular context, not prompt summary)
|
| 1870 |
+
if last_time_step_chat is not None:
|
| 1871 |
+
add_context_area("end", last_time_step_chat)
|
| 1872 |
+
|
| 1873 |
+
html_parts.append("</div>") # flow-chat
|
| 1874 |
+
html_parts.extend(["</body>", "</html>"])
|
| 1875 |
+
|
| 1876 |
+
return "\n".join(html_parts)
|
| 1877 |
+
|
| 1878 |
+
|
| 1879 |
+
def export_html_from_rollout_tree(path: Path, outdir: Path, main_only: bool = False):
|
| 1880 |
+
"""Process a rollout tree file and generate HTML files for each path.
|
| 1881 |
+
Creates separate HTML files for the main path and each branch path.
|
| 1882 |
+
The main path is saved in the root output directory, while branch paths
|
| 1883 |
+
are saved in a 'branches' subdirectory.
|
| 1884 |
+
|
| 1885 |
+
Args:
|
| 1886 |
+
path: Path to the rollout tree JSON file
|
| 1887 |
+
outdir: Output directory for HTML files
|
| 1888 |
+
main_only: If True, only export the main trajectory (default: False)
|
| 1889 |
+
"""
|
| 1890 |
+
root = load_rollout_tree(path)
|
| 1891 |
+
mgid = root.id
|
| 1892 |
+
|
| 1893 |
+
main_path, branch_paths = get_rollout_tree_paths(root)
|
| 1894 |
+
|
| 1895 |
+
outdir.mkdir(parents=True, exist_ok=True)
|
| 1896 |
+
|
| 1897 |
+
# Create branches subdirectory if we have branch paths
|
| 1898 |
+
if not main_only and branch_paths:
|
| 1899 |
+
branches_dir = outdir / f"mgid:{mgid}_branches_html_renders"
|
| 1900 |
+
branches_dir.mkdir(parents=True, exist_ok=True)
|
| 1901 |
+
|
| 1902 |
+
# Generate HTML for the main path
|
| 1903 |
+
chat_turns = gather_all_chat_turns_for_path(main_path)
|
| 1904 |
+
html_content = html_from_chat_turns(chat_turns)
|
| 1905 |
+
output_file = outdir / f"mgid:{mgid}_main_html_render.render.html"
|
| 1906 |
+
with open(output_file, "w", encoding="utf-8") as f:
|
| 1907 |
+
f.write(html_content)
|
| 1908 |
+
|
| 1909 |
+
# Generate HTML for each branch path
|
| 1910 |
+
for path_obj in branch_paths:
|
| 1911 |
+
chat_turns = gather_all_chat_turns_for_path(path_obj)
|
| 1912 |
+
|
| 1913 |
+
html_content = html_from_chat_turns(chat_turns)
|
| 1914 |
+
|
| 1915 |
+
path_id: str = path_obj.id
|
| 1916 |
+
output_filename = f"{path_id}_html_render.render.html"
|
| 1917 |
+
|
| 1918 |
+
output_file = branches_dir / output_filename
|
| 1919 |
+
|
| 1920 |
+
with open(output_file, "w", encoding="utf-8") as f:
|
| 1921 |
+
f.write(html_content)
|
src_code_for_reproducibility/utils/rollout_tree_stats.py
ADDED
|
@@ -0,0 +1,50 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Any, Callable, List, Tuple
|
| 2 |
+
|
| 3 |
+
from mllm.markov_games.rollout_tree import RolloutTreeRootNode
|
| 4 |
+
from mllm.markov_games.simulation import SimulationStepLog
|
| 5 |
+
from mllm.utils.rollout_tree_gather_utils import (
|
| 6 |
+
gather_simulation_step_logs,
|
| 7 |
+
get_rollout_tree_paths,
|
| 8 |
+
)
|
| 9 |
+
from mllm.utils.stat_pack import StatPack
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
def get_rollout_tree_stat_tally(
|
| 13 |
+
rollout_tree: RolloutTreeRootNode,
|
| 14 |
+
metrics: List[Callable[[SimulationStepLog], List[Tuple[str, float]]]],
|
| 15 |
+
) -> StatPack:
|
| 16 |
+
stat_tally = StatPack()
|
| 17 |
+
# get simulation step logs
|
| 18 |
+
node_list = get_rollout_tree_paths(rollout_tree)[0]
|
| 19 |
+
simulation_step_logs = gather_simulation_step_logs(node_list)
|
| 20 |
+
for simulation_step_log in simulation_step_logs:
|
| 21 |
+
for metric in metrics:
|
| 22 |
+
metric_result = metric(simulation_step_log)
|
| 23 |
+
if metric_result is not None:
|
| 24 |
+
for key, value in metric_result:
|
| 25 |
+
stat_tally.add_stat(key, value)
|
| 26 |
+
return stat_tally
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
def get_rollout_tree_mean_stats(
|
| 30 |
+
rollout_tree: RolloutTreeRootNode, metrics: List[Callable[[SimulationStepLog], Any]]
|
| 31 |
+
) -> StatPack:
|
| 32 |
+
"""Get the mean stats for a rollout tree."""
|
| 33 |
+
stat_tally = get_rollout_tree_stat_tally(rollout_tree, metrics)
|
| 34 |
+
return stat_tally.mean()
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
def get_mean_rollout_tree_stats(
|
| 38 |
+
rollout_trees: List[RolloutTreeRootNode],
|
| 39 |
+
metrics: List[Callable[[SimulationStepLog], Any]],
|
| 40 |
+
) -> StatPack:
|
| 41 |
+
"""Get the mean stats for a list of rollout trees."""
|
| 42 |
+
# TODO complete this
|
| 43 |
+
stat_tallies = [
|
| 44 |
+
get_rollout_tree_mean_stats(rollout_tree, metrics)
|
| 45 |
+
for rollout_tree in rollout_trees
|
| 46 |
+
]
|
| 47 |
+
mean_stat_tally = StatPack()
|
| 48 |
+
for stat_tally in stat_tallies:
|
| 49 |
+
mean_stat_tally.add_stats(stat_tally)
|
| 50 |
+
return mean_stat_tally.mean()
|
src_code_for_reproducibility/utils/short_id_gen.py
ADDED
|
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import uuid
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
def generate_short_id() -> int:
|
| 5 |
+
"""
|
| 6 |
+
Generates a short unique ID for tracking adapter versions.
|
| 7 |
+
|
| 8 |
+
Returns:
|
| 9 |
+
int: An 8-digit integer ID.
|
| 10 |
+
"""
|
| 11 |
+
return int(str(uuid.uuid4().int)[:8])
|
src_code_for_reproducibility/utils/stat_pack.py
ADDED
|
@@ -0,0 +1,113 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import csv
|
| 2 |
+
import json
|
| 3 |
+
import os
|
| 4 |
+
import pickle
|
| 5 |
+
from collections import Counter
|
| 6 |
+
from copy import deepcopy
|
| 7 |
+
from locale import strcoll
|
| 8 |
+
from statistics import mean
|
| 9 |
+
from typing import Any, Dict, Iterator, List, Optional, Tuple, TypedDict
|
| 10 |
+
|
| 11 |
+
import matplotlib.pyplot as plt
|
| 12 |
+
import numpy as np
|
| 13 |
+
from torch.utils.tensorboard import SummaryWriter
|
| 14 |
+
|
| 15 |
+
plt.style.use(
|
| 16 |
+
"https://raw.githubusercontent.com/dereckpiche/DedeStyle/refs/heads/main/dedestyle.mplstyle"
|
| 17 |
+
)
|
| 18 |
+
|
| 19 |
+
import wandb
|
| 20 |
+
|
| 21 |
+
from . import wandb_utils
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
class StatPack:
|
| 25 |
+
def __init__(self):
|
| 26 |
+
self.data = {}
|
| 27 |
+
|
| 28 |
+
def add_stat(self, key: str, value: float | int | None):
|
| 29 |
+
assert (
|
| 30 |
+
isinstance(value, float) or isinstance(value, int) or value is None
|
| 31 |
+
), f"Value {value} is not a valid type"
|
| 32 |
+
if key not in self.data:
|
| 33 |
+
self.data[key] = []
|
| 34 |
+
self.data[key].append(value)
|
| 35 |
+
|
| 36 |
+
def add_stats(self, other: "StatPack"):
|
| 37 |
+
for key in other.keys():
|
| 38 |
+
self.add_stat(key, other[key])
|
| 39 |
+
|
| 40 |
+
def __getitem__(self, key: str):
|
| 41 |
+
return self.data[key]
|
| 42 |
+
|
| 43 |
+
def __setitem__(self, key: str, value: Any):
|
| 44 |
+
self.data[key] = value
|
| 45 |
+
|
| 46 |
+
def __contains__(self, key: str):
|
| 47 |
+
return key in self.data
|
| 48 |
+
|
| 49 |
+
def __len__(self):
|
| 50 |
+
return len(self.data)
|
| 51 |
+
|
| 52 |
+
def __iter__(self):
|
| 53 |
+
return iter(self.data)
|
| 54 |
+
|
| 55 |
+
def keys(self):
|
| 56 |
+
return self.data.keys()
|
| 57 |
+
|
| 58 |
+
def values(self):
|
| 59 |
+
return self.data.values()
|
| 60 |
+
|
| 61 |
+
def items(self):
|
| 62 |
+
return self.data.items()
|
| 63 |
+
|
| 64 |
+
def mean(self):
|
| 65 |
+
mean_st = StatPack()
|
| 66 |
+
for key in self.keys():
|
| 67 |
+
if isinstance(self[key], list):
|
| 68 |
+
# TODO: exclude None values
|
| 69 |
+
non_none_values = [v for v in self[key] if v is not None]
|
| 70 |
+
if non_none_values:
|
| 71 |
+
mean_st[key] = np.mean(np.array(non_none_values))
|
| 72 |
+
else:
|
| 73 |
+
mean_st[key] = None
|
| 74 |
+
return mean_st
|
| 75 |
+
|
| 76 |
+
def store_plots(self, folder: str):
|
| 77 |
+
os.makedirs(folder, exist_ok=True)
|
| 78 |
+
for key in self.keys():
|
| 79 |
+
plt.figure(figsize=(10, 5))
|
| 80 |
+
plt.plot(self[key])
|
| 81 |
+
plt.title(key)
|
| 82 |
+
plt.savefig(os.path.join(folder, f"{key}.pdf"))
|
| 83 |
+
plt.close()
|
| 84 |
+
|
| 85 |
+
def store_numpy(self, folder: str):
|
| 86 |
+
os.makedirs(folder, exist_ok=True)
|
| 87 |
+
for key in self.keys():
|
| 88 |
+
# Sanitize filename components (avoid slashes, spaces, etc.)
|
| 89 |
+
safe_key = str(key).replace(os.sep, "_").replace("/", "_").replace(" ", "_")
|
| 90 |
+
values = self[key]
|
| 91 |
+
# Convert None to NaN for numpy compatibility
|
| 92 |
+
arr = np.array(
|
| 93 |
+
[(np.nan if (v is None) else v) for v in values], dtype=float
|
| 94 |
+
)
|
| 95 |
+
np.save(os.path.join(folder, f"{safe_key}.npy"), arr)
|
| 96 |
+
|
| 97 |
+
def store_json(self, folder: str, filename: str = "stats.json"):
|
| 98 |
+
os.makedirs(folder, exist_ok=True)
|
| 99 |
+
with open(os.path.join(folder, filename), "w") as f:
|
| 100 |
+
json.dump(self.data, f, indent=4)
|
| 101 |
+
|
| 102 |
+
def store_csv(self, folder: str):
|
| 103 |
+
os.makedirs(folder, exist_ok=True)
|
| 104 |
+
for key in self.keys():
|
| 105 |
+
with open(os.path.join(folder, f"stats.csv"), "w") as f:
|
| 106 |
+
writer = csv.writer(f)
|
| 107 |
+
writer.writerow([key] + self[key])
|
| 108 |
+
|
| 109 |
+
def store_pickle(self, folder: str):
|
| 110 |
+
os.makedirs(folder, exist_ok=True)
|
| 111 |
+
for key in self.keys():
|
| 112 |
+
with open(os.path.join(folder, f"stats.pkl"), "wb") as f:
|
| 113 |
+
pickle.dump(self[key], f)
|
src_code_for_reproducibility/utils/update_start_epoch.py
ADDED
|
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
|
| 3 |
+
# During run, set hydra.run.dir=./outputs/{folder}
|
| 4 |
+
def update_start_epoch(cfg, output_directory):
|
| 5 |
+
if cfg["experiment"]["resume_experiment"]:
|
| 6 |
+
folders = [f for f in os.listdir(output_directory) if f.startswith("iteration_")]
|
| 7 |
+
iterations = [int(f.split("_")[1]) for f in folders] if folders else [0]
|
| 8 |
+
cfg["experiment"]["start_epoch"] = max(iterations)
|
| 9 |
+
return None
|