import torch
import torchvision.transforms as transforms
from torchvision.utils import save_image
import cv2
import os
import math
from tqdm import tqdm
import numpy as np
import matplotlib.pyplot as plt

# GPU detection: use CUDA when it is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")

def lin_log(x, threshold=20):
    """
    Apply a linear mapping below ``threshold`` and a logarithmic mapping above it.

    Values ``x <= threshold`` are scaled by ``log(threshold) / threshold``;
    larger values are mapped to ``log(x)``. Both pieces give ``log(threshold)``
    at ``x == threshold``, so the mapping is continuous there.

    :param x: torch tensor, ndarray or Python number holding linear values
              (intended range 0-255)
    :param threshold: float, transition point from linear to log mapping
                      (intended range 0-255, must be positive)
    :return: float32 torch tensor with the same shape as ``x``
    """
    # The docstring has always promised float/ndarray support, but those types
    # have no ``.double()``; convert them to a tensor first.
    if not torch.is_tensor(x):
        x = torch.as_tensor(x)
    # Do the computation in float64 (no copy is made if x is already float64).
    x = x.double()
    f = (1.0 / threshold) * math.log(threshold)
    # torch.where evaluates both branches; log(0) = -inf is produced for zero
    # inputs but never selected because those inputs take the linear branch.
    y = torch.where(x <= threshold, x * f, torch.log(x))
    return y.float()

def compute_event_frame(frame1, frame2, rgb2gray, contrast_threshold=0.2):
    """
    Compute the event frame between two consecutive frames.

    Both frames are converted to grayscale, scaled by 255, passed through
    ``lin_log`` and differenced; the difference is divided by
    ``contrast_threshold`` and truncated toward zero to a signed per-pixel
    event count.

    :param frame1: Torch tensor of the first frame (shape: C x H x W);
                   presumably normalised to 0-1 since it is multiplied by 255
                   before ``lin_log`` -- confirm against the caller
    :param frame2: Torch tensor of the second frame (shape: C x H x W)
    :param rgb2gray: Torch tensor of per-channel grayscale weights (C elements)
    :param contrast_threshold: log-intensity change that corresponds to one
                               event (must be non-zero; default 0.2)
    :return: tuple ``(event_frame, thres_pos, thres_neg)`` of int32 tensors
             (shape: H x W). All three hold the same signed event counts and
             are returned as independent tensors.
    """
    # Reshape the weights to (C, 1, 1) so they broadcast over H x W.
    weights = rgb2gray.reshape(-1, 1, 1).to(frame1.device)
    frame1_gray = torch.sum(frame1 * weights, dim=0)
    frame2_gray = torch.sum(frame2 * weights, dim=0)

    # The positive and negative counts come from the same expression, so the
    # log-intensity difference is evaluated once instead of twice.
    log_delta = lin_log(frame2_gray * 255) - lin_log(frame1_gray * 255)
    event_counts = (log_delta / contrast_threshold).to(torch.int32)

    # Copying positive entries from thres_pos and negative entries from
    # thres_neg into a zero frame reproduces event_counts exactly, so the
    # event frame is simply a copy. Clones keep the three results independent
    # in case a caller modifies one of them in place.
    thres_pos = event_counts
    thres_neg = event_counts.clone()
    event_frame = event_counts.clone()

    return event_frame, thres_pos, thres_neg

def generate_grayscale_event(event_frame, out):
    """
    Normalize an event frame to 0-255 and write it as a grayscale video frame.

    :param event_frame: Tensor representing the event frame (must be non-empty)
    :param out: VideoWriter object for grayscale event video; only its
                ``write`` method is used. NOTE(review): a single-channel frame
                presumably requires the writer to be opened in grayscale
                mode -- confirm at the call site.
    """
    lowest = event_frame.min()
    value_range = event_frame.max() - lowest

    if value_range == 0:
        # Constant frame (e.g. no events at all): min-max normalization would
        # compute 0 / 0 = NaN, whose uint8 cast is undefined. Emit a zero frame.
        event_frame_normalized = torch.zeros_like(event_frame, dtype=torch.uint8)
    else:
        # Min-max normalize to 0-255; the uint8 cast truncates the fraction.
        event_frame_normalized = ((event_frame - lowest) / value_range * 255).to(torch.uint8)

    # Convert to numpy (already uint8) and write to video.
    event_frame_numpy = event_frame_normalized.cpu().numpy()
    out.write(event_frame_numpy)

def generate_rgb_event(event_frame, out_rgb):
    """
    Render an event frame as a 3-channel colour image and write it to a video.

    Pixels with a positive value get channel 0 set to 255, pixels with a
    negative value get channel 2 set to 255, and all remaining pixels are
    painted white (255 in every channel).

    :param event_frame: Tensor representing the event frame (shape: H x W)
    :param out_rgb: VideoWriter object for RGB event video; only its ``write``
                    method is used. NOTE(review): cv2.VideoWriter interprets
                    3-channel frames as BGR, so channel 0 would display as
                    blue and channel 2 as red -- confirm this is intended.
    """
    positive = event_frame > 0
    negative = event_frame < 0

    # Allocate on the input's own device (not the module-level default) so the
    # boolean masks below always live on the same device as the buffer.
    height, width = event_frame.shape[0], event_frame.shape[1]
    event_frame_rgb = torch.zeros((3, height, width), dtype=torch.uint8, device=event_frame.device)
    event_frame_rgb[0][positive] = 255
    event_frame_rgb[2][negative] = 255

    # Pixels that are neither positive nor negative are still black in every
    # channel; paint them white.
    black_mask = ~(positive | negative)
    event_frame_rgb[:, black_mask] = 255

    # Reorder to H x W x C and make the memory contiguous before handing the
    # array to the writer (permute alone only returns a strided view).
    event_frame_rgb_numpy = event_frame_rgb.permute(1, 2, 0).contiguous().cpu().numpy()
    out_rgb.write(event_frame_rgb_numpy)
