import argparse
from pprint import pprint

import torch

from zoedepth.utils.easydict import EasyDict as edict
import torch.optim as optim
import torch.cuda.amp as amp
import torch.nn as nn
from zoedepth.utils.misc import count_parameters, parallelize
from zoedepth.trainers.builder import get_trainer
import torch.multiprocessing as mp
from zoedepth.data.data_mono import DepthDataLoader
from zoedepth.models.builder import build_model,build_router
from zoedepth.data.data_mono import MixedNYUKITTI
from zoedepth.utils.arg_utils import parse_unknown
from zoedepth.utils.config import change_dataset, get_config, ALL_EVAL_DATASETS, ALL_INDOOR, ALL_OUTDOOR
from zoedepth.trainers.loss import GradL1Loss, SILogLoss, FocalLossV1
import zoedepth.utils.logging as logging


def main(config):
    """Train the router network alongside a depth model.

    Builds the depth model (``build_model``) and the router (``build_router``)
    from ``config``, creates the 'train' and 'online_eval' data loaders, moves
    both models to CUDA and then runs ``config.epochs`` epochs, delegating each
    epoch to ``validate_router``. Only the router's parameters are handed to
    the optimizer, so the depth model's weights are never stepped.

    Args:
        config: Configuration object (as produced by ``get_config``). This
            function reads ``config.epochs`` and, optionally,
            ``config.router_lr`` (Adam learning rate, default ``1e-4``). The
            whole object is also forwarded to ``build_model``,
            ``build_router``, ``DepthDataLoader`` and ``validate_router``.
    """
    device = torch.device('cuda')

    # load zoedepth_nk_generic
    depth_model = build_model(config)
    router = build_router(config)

    # load data
    train_loader = DepthDataLoader(config, 'train').data
    # NOTE(review): test_loader is constructed but not used anywhere yet.
    test_loader = DepthDataLoader(config, 'online_eval').data

    # Move both networks to the same device object that is passed downstream.
    depth_model = depth_model.to(device)
    router = router.to(device)

    # Training settings
    criterion_d = SILogLoss()
    # Learning rate is configurable; falls back to the previous fixed value.
    router_lr = getattr(config, 'router_lr', 1e-4)
    optimizer = optim.Adam(router.parameters(), router_lr)

    # Module-level step counter, reset at the start of every run.
    global global_step
    global_step = 0

    # training
    for epoch in range(1, config.epochs + 1):
        print('\nEpoch: %03d - %03d' % (epoch, config.epochs))
        validate_router(depth_model, router, train_loader, optimizer,
                        criterion_d, epoch, device, config)
 
def validate_router(focal_model, router_model, loader=None, optimizer=None,
                    criterion=None, epoch=None, device=None, config=None):
    """Run one epoch of router training / validation (placeholder).

    The signature accepts the full per-epoch argument list
    ``(focal_model, router_model, loader, optimizer, criterion, epoch,
    device, config)``. Every parameter after ``router_model`` defaults to
    ``None``, so the original two-argument call form keeps working.

    Args:
        focal_model: The depth model the router works with.
        router_model: The router network being trained.
        loader: Iterable of training batches (unused until implemented).
        optimizer: Optimizer over the router's parameters (unused).
        criterion: Loss function (unused).
        epoch: 1-based epoch index (unused).
        device: Device the models live on (unused).
        config: Run configuration (unused).

    Returns:
        None. The epoch loop is not implemented yet, so nothing is computed.
    """
    # TODO: implement the per-epoch loop over ``loader``; currently a no-op.
    return None



def infer(model, rgb):
    """Placeholder for running ``model`` on an input (presumably an RGB image).

    Not implemented: both arguments are ignored and ``None`` is returned.
    """
    return None


if __name__ == '__main__':
    # Select the multiprocessing start method before any worker processes exist.
    mp.set_start_method('forkserver')

    # Machine-specific default checkpoints; override them on the command line.
    default_pretrained_resource = 'local::/home/yss/桌面/code-master/Long-tail/weights/ZoeDepthNKgeneric_10-May_21-34-678642d67723_latest.pt'
    default_router_pretrained_resource = 'local::/home/yss/桌面/code-master/Long-tail/weights/ZoeDepthv1_09-May_13-44-e2e00279f3b7_best.pt'

    parser = argparse.ArgumentParser()
    # Optional so that the defaults below take effect when the flag is omitted.
    parser.add_argument("-m", "--model", type=str, required=False,
                        default='zoedepth_nk_generic', help="Name of the model to evaluate")
    parser.add_argument("-p", "--pretrained_resource", type=str, required=False,
                        default=default_pretrained_resource,
                        help="Pretrained resource for the depth model. Defaults to a local checkpoint path. Refer models.model_io.load_state_from_resource for more details.")
    parser.add_argument("-d", "--dataset", type=str, required=False,
                        default='nyu', help="Dataset to evaluate on")
    # NOTE(review): --version_name is parsed but not forwarded to get_config.
    parser.add_argument("-v", "--version_name", type=str, required=False,
                        default='generic', help="version_name")
    parser.add_argument("--batch_size", type=int, required=False,
                        default=12, help="Training batch size")
    parser.add_argument("--epochs", type=int, required=False,
                        default=5, help="Number of router training epochs")
    parser.add_argument("--router_model", type=str, required=False,
                        default='zoedepth', help="Name of the router model")
    parser.add_argument("--router_pretrained_resource", type=str, required=False,
                        default=default_router_pretrained_resource,
                        help="Pretrained resource for the router model")

    # NOTE(review): unknown_args are collected but ignored (parse_unknown is
    # imported at the top of the file but never applied).
    args, unknown_args = parser.parse_known_args()
    overwrite = {"pretrained_resource": args.pretrained_resource}
    config = get_config(args.model, "eval", args.dataset, **overwrite)

    # Add new parameters from the CLI instead of hard-coding them, so the
    # flags above are not silently discarded. Defaults match the old values.
    config.batch_size = args.batch_size
    config.distributed = False
    config.epochs = args.epochs
    config.model = args.model
    config.pretrained_resource = args.pretrained_resource
    config.router_model = args.router_model
    config.router_pretrained_resource = args.router_pretrained_resource

    main(config)
    print('done')