import argparse
from pprint import pprint

import torch
from zoedepth.utils.easydict import EasyDict as edict
from tqdm import tqdm
import torch.nn as nn
from zoedepth.data.data_mono import DepthDataLoader
from zoedepth.data.sun_rgbd_loader import get_sunrgbd_loader
from zoedepth.models.builder import build_model
from zoedepth.utils.arg_utils import parse_unknown
from zoedepth.utils.config import change_dataset, get_config, ALL_EVAL_DATASETS, ALL_INDOOR, ALL_OUTDOOR
from zoedepth.utils.misc import (RunningAverageDict, colors, compute_metrics,compute_every_err,
                        count_parameters)
# Example usage:
# python evaluate_tail.py -m zoedepth_nk_generic --pretrained_resource="local::/home/yss/桌面/code-master/Long-tail/weights/ZoeDepthNKgeneric_23-Apr_10-59-b58b1ebca21e_best.pt" -d nyu

@torch.no_grad()
def infer(model, images, **kwargs):
    """Run one forward pass and normalise the model output.

    Flip-augmented averaging is currently disabled, so exactly one call to
    ``model`` is made.

    Args:
        model: Callable invoked as ``model(images, **kwargs)``.
        images: Input batch handed to ``model`` unchanged.
        **kwargs: Extra keyword arguments forwarded to ``model``.

    Returns:
        The model output as-is when it is a ``torch.Tensor`` or a ``dict``;
        the last element when it is a ``list`` or ``tuple``.

    Raises:
        NotImplementedError: If the output is of any other type.
    """
    output = model(images, **kwargs)

    # Tensors and dicts are handed back untouched.
    if isinstance(output, (torch.Tensor, dict)):
        return output

    # Sequence outputs: only the final entry is kept.
    if isinstance(output, (list, tuple)):
        return output[-1]

    raise NotImplementedError(f"Unknown output type {type(output)}")


def _max_depth_eval_for(dataset):
    """Return the ``max_depth_eval`` cap used by ``evaluate`` for ``dataset``.

    Raises:
        ValueError: If ``dataset`` has no cap defined here.
    """
    indoor = ('ibims', 'nyu', 'diode_indoor')
    outdoor = ('kitti', 'vkitti2', 'diode_outdoor', 'diml_outdoor')
    if dataset in indoor:
        return 10.0
    if dataset in outdoor:
        return 80.0
    raise ValueError(
        f"No max_depth_eval defined for dataset {dataset!r}; "
        f"expected one of {indoor + outdoor}")


def _save_visualizations(indx, image, depth, router, focal_depth, pred,
                         out_dir, max_depth=80.):
    """Write the input image and colorized depth maps for sample ``indx`` as PNGs in ``out_dir``."""
    # Imported lazily so these packages are only needed when saving is enabled.
    import os
    from PIL import Image
    import torchvision.transforms as transforms
    from zoedepth.utils.misc import colorize

    os.makedirs(out_dir, exist_ok=True)

    def save_depth(tensor, suffix):
        colored = colorize(tensor.squeeze().cpu().numpy(), 0, max_depth)
        Image.fromarray(colored).save(os.path.join(out_dir, f"{indx}_{suffix}.png"))

    rgb = transforms.ToPILImage()(image.squeeze().cpu())
    rgb.save(os.path.join(out_dir, f"{indx}_img.png"))
    save_depth(depth, "depth")
    save_depth(router, "router")
    save_depth(focal_depth, "pred")
    save_depth(pred['near'], "near")
    save_depth(pred['middle'], "middle")
    save_depth(pred['wide'], "wide")
    save_depth(pred['ultra'], "ultra")


@torch.no_grad()
def evaluate(model, test_loader, config, round_vals=True, round_precision=3):
    """Evaluate a multi-head depth model by fusing its per-range predictions.

    For every sample the model is run through ``infer``; its output is indexed
    by the head names ``'near'``, ``'middle'``, ``'wide'``, ``'ultra'`` and
    ``'generic'``. The heads are resized to the ground-truth resolution and
    fused per pixel according to a routing depth map, then scored with
    ``compute_every_err``.

    For ``kitti``/``vkitti2`` the image is split into three overlapping
    480-pixel-wide crops, and for ``diml_outdoor`` into a 2x4 grid of 540x480
    crops; the per-crop predictions are stitched back before fusion.

    Args:
        model: Network to evaluate; it is switched to eval mode.
        test_loader: Iterable of sample dicts. The keys read here are
            ``'image'``, ``'depth'``, ``'dataset'`` and, optionally,
            ``'has_valid_depth'`` and ``'focal'``.
        config: Configuration object. ``config.dataset`` selects the tiling
            strategy and the depth cap; a truthy ``config.save_images``
            enables PNG dumps into ``visualization/``.
        round_vals: Whether to round the final metric values.
        round_precision: Number of decimals used when rounding.

    Returns:
        A list of 10 entries. Entry 0 is a dict of (optionally rounded)
        metrics once at least one sample has been scored; entries that never
        produced a value remain ``RunningAverageDict`` instances.

    Raises:
        ValueError: If ``config.dataset`` has no evaluation depth cap.
    """
    model.eval()
    metrics_list = [RunningAverageDict() for _ in range(10)]
    seg_test = False

    # Resolved from the config (not from a module-level global) and before the
    # loop, so an unsupported dataset fails immediately with a clear message.
    max_depth_eval = _max_depth_eval_for(config.dataset)

    is_kitti_like = config.dataset == 'vkitti2' or config.dataset == 'kitti'
    is_diml_outdoor = config.dataset == 'diml_outdoor'

    tile_w = 480  # width of every crop fed to the model in the tiled modes
    fused_heads = ('near', 'middle', 'wide', 'ultra', 'generic')
    kitti_heads = ('near', 'middle1', 'middle2', 'middle',
                   'wide1', 'wide2', 'wide', 'ultra', 'generic')

    # Routing thresholds on the router depth and the half-width of the
    # blending band around each threshold.
    near, middle, wide = 1., 3.5, 5.5
    soft_err = (0.35, 0.35, 0.35)
    # (threshold, margin, head below the threshold, head above the threshold)
    transitions = (
        (near, soft_err[0], 'near', 'middle'),
        (middle, soft_err[1], 'middle', 'wide'),
        (wide, soft_err[2], 'wide', 'ultra'),
    )
    smoothing = True

    for indx, sample in tqdm(enumerate(test_loader), total=len(test_loader)):
        if 'has_valid_depth' in sample and not sample['has_valid_depth']:
            continue

        image, depth = sample['image'].cuda(), sample['depth'].cuda()
        o_img = image  # untiled input, kept for visualization
        depth = depth.squeeze().unsqueeze(0).unsqueeze(0)
        if depth.max() < 0.01:
            # No usable ground truth in this sample.
            continue

        focal = sample.get('focal', torch.Tensor(
            [715.0873]).cuda())  # This magic number (focal) is only used for evaluating BTS model

        # ---- split the image into crops -------------------------------------
        if is_kitti_like:
            bs, _, h, w = image.shape
            assert w > h and bs == 1
            num_tiles = 3
            stride = (w - tile_w) // (num_tiles - 1)  # e.g. (1216 - 480) // 2 = 368
            col_starts = [i * stride for i in range(num_tiles)]
            # Counts how many crops cover each pixel, for averaging overlaps.
            # NOTE(review): if (w - tile_w) is not divisible by num_tiles - 1,
            # the right-most columns are uncovered and become NaN after the
            # division below.
            coverage = torch.zeros((bs, 1, h, w), device=image.device)
            for x0 in col_starts:
                coverage[..., :, x0:x0 + tile_w] += 1
            image = torch.cat(
                [image[..., :, x0:x0 + tile_w] for x0 in col_starts], dim=0)
        elif is_diml_outdoor:
            assert image.shape[0] == 1  # batch size must be 1
            assert image.shape[1] == 3  # RGB image
            bs, _, h, w = image.shape
            assert h == 1080 and w == 1920
            tile_h = 540
            grid_rows, grid_cols = 2, 4
            row_stride = (h - tile_h) // (grid_rows - 1)  # 540 // 1 = 540
            col_stride = (w - tile_w) // (grid_cols - 1)  # 1440 // 3 = 480
            tile_origins = [(i * row_stride, j * col_stride)
                            for i in range(grid_rows) for j in range(grid_cols)]
            coverage = torch.zeros((bs, 1, h, w), device=image.device)
            for y0, x0 in tile_origins:
                coverage[..., y0:y0 + tile_h, x0:x0 + tile_w] += 1
            image = torch.cat(
                [image[..., y0:y0 + tile_h, x0:x0 + tile_w] for y0, x0 in tile_origins],
                dim=0)

        pred = infer(model, image, dataset=sample['dataset'][0], focal=focal)

        # ---- stitch per-crop predictions back to full resolution ------------
        if is_kitti_like:
            for head in kitti_heads:
                # Not every model variant exposes the intermediate heads
                # (middle1, wide1, ...); only stitch what is present.
                if head not in pred:
                    continue
                stitched = torch.zeros((bs, 1, h, w), device=depth.device)
                for i in range(pred[head].shape[0]):
                    x0 = i * stride
                    # Resize to the crop size actually taken from the image
                    # (uses h instead of a fixed 352).
                    tile = nn.functional.interpolate(
                        pred[head][i].squeeze().unsqueeze(0).unsqueeze(1),
                        [h, tile_w], mode='bilinear', align_corners=True)
                    stitched[..., :, x0:x0 + tile_w] += tile
                pred[head] = stitched / coverage
        elif is_diml_outdoor:
            for head in fused_heads:
                stitched = torch.zeros((bs, 1, h, w), device=depth.device)
                for i in range(pred[head].shape[0]):
                    y0, x0 = tile_origins[i]
                    tile = nn.functional.interpolate(
                        pred[head][i].squeeze().unsqueeze(0).unsqueeze(1),
                        [tile_h, tile_w], mode='bilinear', align_corners=True)
                    stitched[..., y0:y0 + tile_h, x0:x0 + tile_w] = tile
                pred[head] = stitched / coverage

        # ---- bring every fused head to the ground-truth resolution ----------
        for head in fused_heads:
            pred[head] = nn.functional.interpolate(
                pred[head], depth.shape[-2:], mode='bilinear', align_corners=True)

        # ---- post-process: per-pixel combination of the heads ---------------
        # Routing currently uses the ground-truth depth itself.
        router = depth
        # Clone so the in-place writes below do not overwrite pred['near'],
        # which is still needed for blending and visualization.
        focal_depth = pred['near'].clone()

        for threshold, margin, _, upper in transitions:
            above = router >= threshold + margin
            focal_depth[above] = pred[upper][above]

        if smoothing:
            # Inside the band around each threshold, average the two
            # neighbouring heads.
            for threshold, margin, lower, upper in transitions:
                band = torch.logical_and(router < threshold + margin,
                                         router > threshold - margin)
                focal_depth[band] = (pred[lower][band] + pred[upper][band]) / 2.0

        Ms = compute_every_err(depth, focal_depth, seg_test=seg_test, min_depth_eval=0.1,
                               max_depth_eval=max_depth_eval, config=config)

        # Save image, depth, pred for visualization
        if "save_images" in config and config.save_images:
            config.save_images = "visualization/"
            _save_visualizations(indx, o_img, depth, router, focal_depth, pred,
                                 config.save_images)

        if seg_test:
            for i in range(len(Ms)):
                if Ms[i] != 0:
                    metrics_list[i].update(Ms[i])
        else:
            metrics_list[0].update(Ms)

    if round_vals:
        def _fmt(value): return round(value, round_precision)
    else:
        def _fmt(value): return value

    # Only slot 0 is populated unless seg_test is enabled. Slots without any
    # accumulated value are left untouched instead of crashing on None.
    slots = range(len(metrics_list)) if seg_test else (0,)
    for i in slots:
        summary = metrics_list[i].get_value()
        if summary is not None:
            metrics_list[i] = {k: _fmt(v) for k, v in summary.items()}
    print(metrics_list)
    return metrics_list

def main(config):
    """Build the model described by ``config``, evaluate it and return the metrics.

    The data loader is created for the ``'online_eval'`` mode and the model is
    moved to CUDA before evaluation.

    Args:
        config: Configuration passed to ``build_model``, ``DepthDataLoader``
            and ``evaluate``.

    Returns:
        Whatever ``evaluate`` returns for this model and loader.
    """
    net = build_model(config)
    eval_loader = DepthDataLoader(config, 'online_eval').data
    net = net.cuda()
    return evaluate(net, eval_loader, config)


def eval_model(model_name, pretrained_resource, dataset='nyu', **kwargs):
    """Evaluate ``model_name`` on ``dataset`` and return the resulting metrics.

    Args:
        model_name: Model name passed to ``get_config``.
        pretrained_resource: Weight resource override. When falsy, no override
            is added and the value from ``kwargs`` or the config defaults
            applies.
        dataset: Dataset name passed to ``get_config``.
        **kwargs: Additional config overrides forwarded to ``get_config``.

    Returns:
        The metrics produced by ``main`` for the assembled config.
    """
    overrides = dict(kwargs)
    if pretrained_resource:
        # An explicit resource takes precedence over one supplied via kwargs.
        overrides["pretrained_resource"] = pretrained_resource

    config = get_config(model_name, "eval", dataset, **overrides)
    pprint(config)
    print(f"Evaluating {model_name} on {dataset}...")
    return main(config)


if __name__ == '__main__':
    # Command-line entry point: parse the known options, treat everything
    # else as config overrides, then evaluate each requested dataset in turn.
    cli = argparse.ArgumentParser()
    cli.add_argument(
        "-m", "--model", type=str, required=True,
        help="Name of the model to evaluate")
    cli.add_argument(
        "-p", "--pretrained_resource", type=str, required=False, default=None,
        help="Pretrained resource to use for fetching weights. If not set, default resource from model config is used,  Refer models.model_io.load_state_from_resource for more details.")
    cli.add_argument(
        "-d", "--dataset", type=str, required=False, default='nyu',
        help="Dataset to evaluate on")

    args, unknown_args = cli.parse_known_args()
    overwrite_kwargs = parse_unknown(unknown_args)

    # Group aliases are matched by substring, most specific first, so that
    # "ALL_INDOOR"/"ALL_OUTDOOR" win over the plain "ALL" alias.
    group_aliases = (
        ("ALL_INDOOR", ALL_INDOOR),
        ("ALL_OUTDOOR", ALL_OUTDOOR),
        ("ALL", ALL_EVAL_DATASETS),
    )
    datasets = next(
        (group for alias, group in group_aliases if alias in args.dataset), None)
    if datasets is None:
        # Either a comma-separated list or a single name; split() yields a
        # one-element list when there is no comma.
        datasets = args.dataset.split(",")

    for dataset in datasets:
        eval_model(args.model, pretrained_resource=args.pretrained_resource,
                   dataset=dataset, **overwrite_kwargs)