import argparse
from pprint import pprint

import torch
from zoedepth.utils.easydict import EasyDict as edict
from tqdm import tqdm
import torch.nn as nn
from zoedepth.data.data_mono import DepthDataLoader
from zoedepth.models.builder import build_model
from zoedepth.utils.arg_utils import parse_unknown
from zoedepth.utils.config import change_dataset, get_config, ALL_EVAL_DATASETS, ALL_INDOOR, ALL_OUTDOOR
from zoedepth.utils.misc import (RunningAverageDict, colors, compute_metrics,compute_every_err,
                        count_parameters)
# Example usage (replace the path with the location of your checkpoint):
# python evaluate_tail.py -m zoedepth_nk_generic --pretrained_resource="local::/path/to/weights/ZoeDepthNKgeneric_23-Apr_10-59-b58b1ebca21e_best.pt" -d nyu

@torch.no_grad()
def infer(model, images, **kwargs):
    """Run a single gradient-free forward pass of ``model`` on ``images``.

    No flip (or any other test-time) augmentation is applied. The raw model
    output is unwrapped as follows:

    * a ``torch.Tensor`` is returned as-is,
    * a list/tuple yields its last element,
    * a dict is returned unchanged (callers index it by key).

    Args:
        model: callable invoked as ``model(images, **kwargs)``.
        images: input batch, laid out as (N, C, H, W).
        **kwargs: forwarded verbatim to ``model``.

    Returns:
        The unwrapped prediction.

    Raises:
        NotImplementedError: if the model output is of any other type.
    """
    pred = model(images, **kwargs)
    if isinstance(pred, (list, tuple)):
        # Multi-output models: the last element is the final prediction.
        return pred[-1]
    if isinstance(pred, (torch.Tensor, dict)):
        return pred
    raise NotImplementedError(f"Unknown output type {type(pred)}")


# Keys of the focal-range heads in the model's output dict, ordered from the
# closest depth range to the farthest.
_FOCAL_HEADS = ('near', 'middle', 'wide', 'ultra')
# KITTI frames are evaluated as this many overlapping crops of this width.
_KITTI_WINDOW = 480
_KITTI_NUM_WINDOWS = 3


def _sliding_window_starts(width, window, num_windows):
    """Return the left column offsets of ``num_windows`` crops ``window`` wide.

    Offsets are spaced by ``(width - window) // (num_windows - 1)``, except
    that the last one is pinned to ``width - window`` so the rightmost columns
    are always covered. With plain floor spacing, an odd remainder left the
    last column uncovered and the later coverage division produced NaNs.
    Requires ``num_windows >= 2`` and ``width >= window``.
    """
    span = width - window
    step = span // (num_windows - 1)
    return [i * step for i in range(num_windows - 1)] + [span]


def _merge_sliding_predictions(crop_preds, starts, height, width, window, coverage):
    """Stitch per-crop predictions back into a single (1, 1, height, width) map.

    Each crop prediction is resized to ``(height, window)`` and added at its
    column offset. The sum is divided by ``coverage`` (the number of crops
    covering each pixel), so overlapping regions are averaged.
    """
    merged = torch.zeros((1, 1, height, width), device=coverage.device)
    for crop_pred, start in zip(crop_preds, starts):
        resized = nn.functional.interpolate(
            crop_pred.squeeze().unsqueeze(0).unsqueeze(1), [height, window],
            mode='bilinear', align_corners=True)
        merged[..., :, start:start + window] += resized
    return merged / coverage


def _combine_focal_predictions(pred, depth, thresholds=(3.25, 8.25, 25.25), blend=0.5):
    """Fuse the per-head predictions in ``pred`` into a single depth map.

    Each pixel takes the prediction of the head whose range contains
    ``depth`` at that pixel:
    'near' below ``thresholds[0]``, 'middle' from ``thresholds[0]``, 'wide'
    from ``thresholds[1]`` and 'ultra' from ``thresholds[2]``. Pixels strictly
    inside ``(t - blend, t)`` just below each threshold ``t`` get the mean of
    the two neighbouring heads, which softens the seam.

    NOTE: the head selection is driven by ``depth``, the same tensor the
    result is scored against, so this is an oracle fusion.

    Returns:
        A new tensor; ``pred`` is not modified.
    """
    # Clone so the 'near' head tensor in ``pred`` is not mutated in place.
    combined = pred[_FOCAL_HEADS[0]].clone()
    for head, threshold in zip(_FOCAL_HEADS[1:], thresholds):
        mask = depth >= threshold
        combined[mask] = pred[head][mask]
    for lower, upper, threshold in zip(_FOCAL_HEADS, _FOCAL_HEADS[1:], thresholds):
        band = (depth > threshold - blend) & (depth < threshold)
        combined[band] = (pred[lower][band] + pred[upper][band]) / 2.0
    return combined


def _save_visualization(out_dir, index, image, depth, prediction):
    """Write the input image and the colorized depth/prediction (range 0-10) as PNGs."""
    import os
    from PIL import Image
    import torchvision.transforms as transforms
    from zoedepth.utils.misc import colorize

    os.makedirs(out_dir, exist_ok=True)
    d = colorize(depth.squeeze().cpu().numpy(), 0, 10)
    p = colorize(prediction.squeeze().cpu().numpy(), 0, 10)
    im = transforms.ToPILImage()(image.squeeze().cpu())
    im.save(os.path.join(out_dir, f"{index}_img.png"))
    Image.fromarray(d).save(os.path.join(out_dir, f"{index}_depth.png"))
    Image.fromarray(p).save(os.path.join(out_dir, f"{index}_pred.png"))


@torch.no_grad()
def evaluate(model, test_loader, config, round_vals=True, round_precision=3):
    """Evaluate a multi-focal-head depth model over ``test_loader``.

    For every sample the model is run once; on KITTI it runs on overlapping
    fixed-width crops that are stitched back together. Each focal head's
    prediction is resized to the depth map's resolution, the heads are fused
    with ``_combine_focal_predictions``, and the fused map is scored with
    ``compute_every_err``. Visualizations are optionally written to
    ``config.save_images``.

    Args:
        model: network whose output is a dict keyed by 'near', 'middle',
            'wide' and 'ultra'.
        test_loader: iterable of sample dicts with 'image', 'depth' and
            'dataset' (optionally 'focal' and 'has_valid_depth'). Presumably
            batch size 1 (TODO confirm): the depth squeeze below and the
            KITTI path rely on it.
        config: evaluation config. Reads ``dataset`` and ``save_images`` and
            is forwarded to ``compute_every_err``.
        round_vals: round the averaged metrics when True.
        round_precision: number of decimals used when rounding.

    Returns:
        list: ten entries. Entry 0 is a dict of averaged metrics; entries 1-9
        are unused ``RunningAverageDict`` placeholders kept for compatibility.

    Raises:
        ValueError: on KITTI, if the image is not a single landscape frame at
            least ``_KITTI_WINDOW`` pixels wide.
    """
    model.eval()
    # Only slot 0 (the aggregate over all valid samples) is ever populated.
    metrics_list = [RunningAverageDict() for _ in range(10)]
    for indx, sample in tqdm(enumerate(test_loader), total=len(test_loader)):
        if 'has_valid_depth' in sample and not sample['has_valid_depth']:
            continue
        image = sample['image'].cuda()
        depth = sample['depth'].cuda()
        o_img = image  # un-cropped input, kept for visualization
        depth = depth.squeeze().unsqueeze(0).unsqueeze(0)  # -> (1, 1, H, W)
        focal = sample.get('focal', torch.Tensor(
            [715.0873]).cuda())  # This magic number (focal) is only used for evaluating BTS model

        is_kitti = config.dataset == 'kitti'
        if is_kitti:
            # Wide KITTI frames are split into overlapping crops, batched
            # together, and the predictions are stitched back afterwards.
            bs, _, h, w = image.shape
            if not (w > h and bs == 1 and w >= _KITTI_WINDOW):
                raise ValueError(
                    "KITTI sliding-window evaluation expects a single landscape "
                    f"image at least {_KITTI_WINDOW} px wide, got {tuple(image.shape)}")
            starts = _sliding_window_starts(w, _KITTI_WINDOW, _KITTI_NUM_WINDOWS)
            coverage = torch.zeros((bs, 1, h, w), device=image.device)
            crops = []
            for start in starts:
                crops.append(image[..., :, start:start + _KITTI_WINDOW])
                coverage[..., :, start:start + _KITTI_WINDOW] += 1
            image = torch.cat(crops, dim=0)

        pred = infer(model, image, dataset=sample['dataset'][0], focal=focal)

        for head in _FOCAL_HEADS:
            if is_kitti:
                pred[head] = _merge_sliding_predictions(
                    pred[head], starts, h, w, _KITTI_WINDOW, coverage)
            pred[head] = nn.functional.interpolate(
                pred[head], depth.shape[-2:], mode='bilinear', align_corners=True)

        focal_depth = _combine_focal_predictions(pred, depth)
        errors = compute_every_err(depth, focal_depth, config=config)

        if "save_images" in config and config.save_images:
            _save_visualization(config.save_images, indx, o_img, depth, focal_depth)

        metrics_list[0].update(errors)

    if round_vals:
        def r(m): return round(m, round_precision)
    else:
        def r(m): return m
    # NOTE(review): RunningAverageDict.get_value() presumably returns None when
    # no sample was recorded; treat that as "no metrics" instead of crashing.
    averaged = metrics_list[0].get_value() or {}
    metrics_list[0] = {k: r(v) for k, v in averaged.items()}
    print(metrics_list)
    return metrics_list

def main(config):
    """Build the model and data loader from ``config``, evaluate, and report metrics.

    Args:
        config: evaluation config, passed to ``build_model``,
            ``DepthDataLoader`` and ``evaluate``.

    Returns:
        dict: the aggregated metrics (entry 0 of ``evaluate``'s result) plus
        a ``'#params'`` entry with the total parameter count in millions.
    """
    model = build_model(config)
    test_loader = DepthDataLoader(config, 'online_eval').data
    model = model.cuda()
    metrics_list = evaluate(model, test_loader, config)
    # ``evaluate`` returns a list whose entry 0 holds the aggregated metrics
    # dict. Indexing the list itself with string keys ('global', '#params',
    # ...) raised TypeError, so work on a copy of that dict instead.
    metrics = dict(metrics_list[0])
    print(f"{colors.fg.green}")
    pprint(metrics)
    print(f"{colors.reset}")
    metrics['#params'] = f"{round(count_parameters(model, include_all=True)/1e6, 2)}M"
    return metrics


def eval_model(model_name, pretrained_resource, dataset='nyu', **kwargs):
    """Load the eval config for ``model_name`` on ``dataset`` and run ``main``.

    Args:
        model_name: model identifier understood by ``get_config``.
        pretrained_resource: weights location. When falsy, the default
            resource from the model config is kept.
        dataset: dataset name to evaluate on.
        **kwargs: extra config overrides forwarded to ``get_config``.

    Returns:
        Whatever ``main`` returns for the resulting config.
    """
    overrides = dict(kwargs)
    if pretrained_resource:
        overrides["pretrained_resource"] = pretrained_resource
    config = get_config(model_name, "eval", dataset, **overrides)
    pprint(config)
    print(f"Evaluating {model_name} on {dataset}...")
    return main(config)


def _resolve_datasets(spec):
    """Expand the ``--dataset`` argument into a list of dataset names.

    ``ALL_INDOOR``, ``ALL_OUTDOOR`` and ``ALL`` (matched as substrings, checked
    in that order) select the predefined dataset groups. Any other value is
    treated as a comma-separated list; surrounding whitespace and empty
    entries are dropped, so ``"nyu, kitti"`` gives ``["nyu", "kitti"]``.
    """
    if "ALL_INDOOR" in spec:
        return ALL_INDOOR
    if "ALL_OUTDOOR" in spec:
        return ALL_OUTDOOR
    if "ALL" in spec:
        return ALL_EVAL_DATASETS
    return [name.strip() for name in spec.split(",") if name.strip()]


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument("-m", "--model", type=str,
                        required=True, help="Name of the model to evaluate")
    parser.add_argument("-p", "--pretrained_resource", type=str,
                        required=False, default=None, help="Pretrained resource to use for fetching weights. If not set, default resource from model config is used,  Refer models.model_io.load_state_from_resource for more details.")
    parser.add_argument("-d", "--dataset", type=str, required=False,
                        default='nyu', help="Dataset to evaluate on")

    args, unknown_args = parser.parse_known_args()
    # Unrecognized "--key value" pairs become config overrides.
    overwrite_kwargs = parse_unknown(unknown_args)

    datasets = _resolve_datasets(args.dataset)
    if not datasets:
        parser.error("--dataset did not name any dataset")

    for dataset in datasets:
        eval_model(args.model, pretrained_resource=args.pretrained_resource,
                   dataset=dataset, **overwrite_kwargs)