from header import *
from .samplers import DistributedBatchSampler
from .sft_dataset import *
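
# `header` is expected to re-export the shared training imports used below
# (torch, torch.distributed and torch.utils.data.DataLoader); SupervisedDataset
# and its collate function come from .sft_dataset.
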
'''
def get_tokenizer(model):
    tokenizer = LlamaTokenizer.from_pretrained(model)
    tokenizer.bos_token_id, tokenizer.eos_token_id = 1, 2
    tokenizer.pad_token = tokenizer.eos_token
    return tokenizer
'''
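# NOTE: the LlamaTokenizer helper above is commented out and unused in this module.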


def load_sft_dataset(args):
    '''
    tokenizer = get_tokenizer(args['model_path'])
    dataset_name = args['models'][args['model']]['stage1_train_dataset']  # SupervisedDataset, str
    data_path = args["data_path"]
    data = globals()[dataset_name](data_path, tokenizer, args['max_length'])  # SupervisedDataset
    '''
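    # build the supervised image-text SFT dataset and a random sampler over the full dataset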
    data = SupervisedDataset(args['data_path'], args['image_root_path'])
    sampler = torch.utils.data.RandomSampler(data)
    world_size = torch.distributed.get_world_size()
    rank = torch.distributed.get_rank()
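    # global batch size = world size * per-GPU micro batch size from the DeepSpeed config;
    # args['world_size'] is expected to match the world size queried above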
    batch_size = args['world_size'] * args['dschf'].config['train_micro_batch_size_per_gpu']
    batch_sampler = DistributedBatchSampler(
        sampler,
        batch_size,
        True,   # drop_last
        rank,
        world_size
    )
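    # batch_sampler already yields this rank's slice of each global batch, so the
    # DataLoader must not also be given batch_size, shuffle, or sampler arguments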
    iter_ = DataLoader(
        data,
        batch_sampler=batch_sampler,
        num_workers=1,
        collate_fn=data.collate,
        pin_memory=True
    )
    return data, iter_, sampler
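

# A minimal usage sketch, assuming torch.distributed is already initialised and that
# `args` carries the keys used above ('data_path', 'image_root_path', 'world_size',
# 'dschf'); the paths and the `dschf` object below are illustrative only.
#
#   args = {
#       'data_path': 'path/to/sft_annotations.json',
#       'image_root_path': 'path/to/images/',
#       'world_size': torch.distributed.get_world_size(),
#       'dschf': dschf,  # e.g. a transformers HfDeepSpeedConfig wrapping the DeepSpeed config
#   }
#   data, train_iter, sampler = load_sft_dataset(args)
#   for batch in train_iter:
#       ...  # this rank's micro-batch, collated by SupervisedDataset.collate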