import json
import cv2
import torch
import torchvision.transforms as transforms
import torch.multiprocessing as mp
from tqdm import tqdm
from multiprocessing import Process, Manager
import argparse
import sys, os
import time
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from utils.utils import save_json_entry, convert_json_line_to_general, find_optimal_thread_count, convert_to_mp4, load_json_any, save_json_any
from event_utils import compute_event_frame, generate_grayscale_event, generate_rgb_event

def parse_config():
    """
    Parse the command-line options of the video-to-event conversion script.

    Every option has a default, so the parser also succeeds when no arguments
    are supplied.

    Returns:
        dict: A dictionary containing the following keys:
            - input_video_root (str): root directory of the input videos
            - input_json_path (str): path to the input JSON file
            - output_video_root (str): root directory for the output videos
            - output_json_path (str): path to the output JSON file
            - max_threads (int): maximum number of worker processes
            - thread_threshold (int): threshold used together with
              max_threads when the process count is chosen
    """
    # The description used to say "Canny" (copied from a sibling script); this
    # script converts videos to event videos.
    parser = argparse.ArgumentParser(
        description="Multi-process video-to-event processing script"
    )
    parser.add_argument(
        "--input_video_root",
        type=str,
        default="HDV_dataset/HDV_original",
        help="Root directory of input videos (default: HDV_dataset/HDV_original)",
    )
    parser.add_argument(
        "--input_json_path",
        type=str,
        default="HDV_dataset/info_input.json",
        help="Path to input JSON file (default: HDV_dataset/info_input.json)",
    )
    parser.add_argument(
        "--output_video_root",
        type=str,
        default="HDV_dataset/HDV_event",
        help="Root directory for output videos (default: HDV_dataset/HDV_event)",
    )
    parser.add_argument(
        "--output_json_path",
        type=str,
        default="HDV_dataset/info_output.json",
        help="Path to output JSON file (default: HDV_dataset/info_output.json)",
    )
    parser.add_argument(
        "--max_threads",
        type=int,
        default=8,
        help="Maximum number of processes (default: 8)",
    )
    # NOTE(review): the exact meaning of this threshold is defined by
    # find_optimal_thread_count, which is not visible here; help text kept as is.
    parser.add_argument(
        "--thread_threshold",
        type=int,
        default=8,
        help="Maximum number of single process (default: 8)",
    )

    args = parser.parse_args()

    return {
        "input_video_root": args.input_video_root,
        "input_json_path": args.input_json_path,
        "output_video_root": args.output_video_root,
        "output_json_path": args.output_json_path,
        "max_threads": args.max_threads,
        "thread_threshold": args.thread_threshold,
    }

def event_process_video_gray(video_path, output_path, device):
    """
    Generate a grayscale event video from the input video.

    Each pair of consecutive frames is passed to ``compute_event_frame``; the
    resulting event frame is handed to ``generate_grayscale_event`` together
    with the output video writer.

    :param video_path: Path to the input video file
    :param output_path: Path to save the output grayscale event video
    :param device: Torch device
    :return: True once every frame has been processed
    :raises ValueError: If the input video or the output writer cannot be
        opened, or if the first frame cannot be read
    """
    # NOTE(review): cv2 decodes frames in BGR channel order, while these are the
    # usual R, G, B luma weights - confirm compute_event_frame accounts for that.
    rgb2gray = torch.tensor([0.2989, 0.5870, 0.1140]).to(device)

    # Both frames of every pair must go through the same conversion. Previously
    # `transform` was undefined here (NameError on the second frame) and the
    # first frame used a different value range than the following ones.
    # ToTensor turns an HxWxC uint8 frame into a CxHxW float tensor in [0, 1],
    # the same conversion event_process_video_rgb uses.
    transform = transforms.ToTensor()

    # Open the video file
    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        raise ValueError(f"Unable to open video file: {video_path}")

    out_gray = None
    try:
        frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        fps = cap.get(cv2.CAP_PROP_FPS)

        # Define the codec and create VideoWriter object
        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
        out_gray = cv2.VideoWriter(output_path, fourcc, fps, (frame_width, frame_height), isColor=False)
        if not out_gray.isOpened():
            # Without this check a failed writer would silently drop every frame.
            raise ValueError(f"Unable to open video writer: {output_path}")

        # Read the first frame
        ret, prev_frame = cap.read()
        if not ret:
            raise ValueError("Failed to read the first frame from the video.")
        prev_frame_tensor = transform(prev_frame).to(torch.float32).to(device)

        while True:
            # Read the next frame; a failed read marks the end of the video
            ret, curr_frame = cap.read()
            if not ret:
                break

            curr_frame_tensor = transform(curr_frame).to(torch.float32).to(device)
            # Only the event frame is needed; the two returned thresholds are unused.
            event_frame, _, _ = compute_event_frame(prev_frame_tensor, curr_frame_tensor, rgb2gray)
            generate_grayscale_event(event_frame, out_gray)
            prev_frame_tensor = curr_frame_tensor
    finally:
        # Release resources even when an error interrupts processing
        cap.release()
        if out_gray is not None:
            out_gray.release()
    return True

def event_process_video_rgb(video_path, output_path, device):
    """
    Generate an RGB event video from the input video.

    Each pair of consecutive frames is passed to ``compute_event_frame``; the
    resulting event frame is handed to ``generate_rgb_event`` together with the
    output video writer.

    :param video_path: Path to the input video file
    :param output_path: Base output path. The video is written to this path
        with every '.mp4' replaced by '_rgb.mp4', not to output_path itself
    :param device: Torch device
    :return: True once every frame has been processed
    :raises ValueError: If the input video or the output writer cannot be
        opened, or if the first frame cannot be read
    """
    # NOTE(review): cv2 decodes frames in BGR channel order, while these are the
    # usual R, G, B luma weights - confirm compute_event_frame accounts for that.
    rgb2gray = torch.tensor([0.2989, 0.5870, 0.1140]).to(device)

    # Transform to convert frames to tensors (HxWxC uint8 -> CxHxW float in [0, 1])
    transform = transforms.ToTensor()

    # Open the video file
    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        raise ValueError(f"Unable to open video file: {video_path}")

    out_rgb = None
    try:
        frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        fps = cap.get(cv2.CAP_PROP_FPS)

        # Define the codec and create VideoWriter object
        rgb_output_path = output_path.replace('.mp4', '_rgb.mp4')
        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
        out_rgb = cv2.VideoWriter(rgb_output_path, fourcc, fps, (frame_width, frame_height), isColor=True)
        if not out_rgb.isOpened():
            # Without this check a failed writer would silently drop every frame
            # and the function would still report success.
            raise ValueError(f"Unable to open video writer: {rgb_output_path}")

        # Read the first frame
        ret, prev_frame = cap.read()
        if not ret:
            raise ValueError("Failed to read the first frame from the video.")
        prev_frame_tensor = transform(prev_frame).to(torch.float32).to(device)

        while True:
            # Read the next frame; a failed read marks the end of the video
            ret, curr_frame = cap.read()
            if not ret:
                break

            curr_frame_tensor = transform(curr_frame).to(torch.float32).to(device)

            # Compute the event frame; the two returned thresholds are unused
            event_frame, _, _ = compute_event_frame(prev_frame_tensor, curr_frame_tensor, rgb2gray)

            # Generate RGB event frame and write to video
            generate_rgb_event(event_frame, out_rgb)

            # Update the previous frame
            prev_frame_tensor = curr_frame_tensor
    finally:
        # Release resources even when an error interrupts processing
        cap.release()
        if out_rgb is not None:
            out_rgb.release()
    return True


def process_videos_in_batch(batch, input_video_root, output_video_root, output_data, gpu_id):
    """
    Process one batch of videos on an explicitly specified GPU.

    For every item the video at ``input_video_root/<video_clip_path>`` is
    converted with ``event_process_video_rgb``. On success the path of the
    written event video is stored under ``video_event_path`` and the item is
    passed to ``save_json_entry``. Failures are printed and the batch continues
    with the next item.

    :param batch: Iterable of dict items; each is read for "video_clip_path"
    :param input_video_root: Root directory the relative clip paths are joined to
    :param output_video_root: Directory the event videos are written to
    :param output_data: Forwarded as second argument to save_json_entry
    :param gpu_id: Index of the CUDA device to use; "cpu" is used when CUDA is
        not available
    """
    device = f"cuda:{gpu_id}" if torch.cuda.is_available() else "cpu"
    for item in tqdm(batch, desc=f"video to event on GPU {gpu_id}"):
        rel_path = item.get("video_clip_path", "")

        input_path = os.path.join(input_video_root, rel_path)
        base_name = os.path.splitext(os.path.basename(rel_path))[0]
        out_name = f"{base_name}_event.mp4"
        output_path = os.path.join(output_video_root, out_name)
        # event_process_video_rgb does not write to output_path itself but to
        # its "_rgb" variant, derived with exactly this expression. Record that
        # path, otherwise the JSON entry points at a file that is never created.
        written_path = output_path.replace('.mp4', '_rgb.mp4')

        try:
            success = event_process_video_rgb(input_path, output_path, device)
            if success:
                item["video_event_path"] = written_path
                save_json_entry(item, output_data)
            else:
                print(f"Failed to process video: {input_path}")
        except Exception as e:
            # Per-item boundary: one broken video must not abort the whole batch.
            print(f"Error processing video {input_path} on GPU {gpu_id}: {e}")



def main():
    """
    Run the whole conversion: read the configuration, determine which input
    items are not yet present in the output JSON (matched by "id"), split them
    into batches, process every batch in its own process with GPUs assigned
    round-robin, and finally convert the output JSON with
    ``convert_json_line_to_general``.

    Raises:
        RuntimeError: If torch reports no CUDA device.
    """
    # 1. Parse configuration
    cfg = parse_config()
    (
        input_video_root,
        input_json_path,
        output_video_root,
        output_json_path,
        max_threads,
        thread_threshold,
    ) = (
        cfg[key]
        for key in (
            "input_video_root",
            "input_json_path",
            "output_video_root",
            "output_json_path",
            "max_threads",
            "thread_threshold",
        )
    )

    # Detect the number of available GPUs
    gpu_count = torch.cuda.device_count()
    if not gpu_count:
        raise RuntimeError("No GPUs available!")
    print(f"Detected {gpu_count} GPUs.")

    # 2. Load input JSON and output JSON (the detected format flag is not needed here)
    input_data, _ = load_json_any(input_json_path)
    output_data, _ = load_json_any(output_json_path)

    # Items whose id already appears in the output JSON are skipped
    done_ids = {entry["id"] for entry in output_data if "id" in entry}
    pending = [entry for entry in input_data if entry.get("id") not in done_ids]

    n_input = len(input_data)
    n_output = len(output_data)
    n_pending = len(pending)

    print(f"Total samples in input JSON: {n_input}")
    print(f"Total samples in output JSON: {n_output}")
    print(f"Number of samples to process: {n_pending}")

    if not pending:
        print("No new video to process. Exiting.")
        return

    worker_count = find_optimal_thread_count(n_pending, max_threads, thread_threshold)
    print(f"Number of processes to use: {worker_count}")

    os.makedirs(output_video_root, exist_ok=True)

    # 3. Split the tasks (ceiling division keeps the batch count <= worker_count)
    per_batch = (n_pending + worker_count - 1) // worker_count
    batches = [
        pending[start:start + per_batch]
        for start in range(0, n_pending, per_batch)
    ]

    started_at = time.time()
    workers = []

    # Assign the batches to the GPUs in round-robin order
    for index, chunk in enumerate(batches):
        worker = Process(
            target=process_videos_in_batch,
            args=(chunk, input_video_root, output_video_root, output_json_path, index % gpu_count),
        )
        workers.append(worker)
        worker.start()

    for worker in workers:
        worker.join()

    # Save the updated output data
    convert_json_line_to_general(output_json_path)
    print("video to event time:", time.time() - started_at)



if __name__ == "__main__":
    mp.set_start_method('spawn', force=True)

    # Hard-coded development defaults. They are applied only when the script is
    # started without any command-line arguments; previously they overwrote
    # sys.argv unconditionally, so explicitly passed options were silently ignored.
    if len(sys.argv) == 1:
        sys.argv = ['/home/yexin/data_processing/HDV_Data_Processing/video2event/v2event.py',
                    '--input_video_root', '/home/yexin/data_processing/HDV_dataset/HDV_clip',
                    '--input_json_path', '/home/yexin/data_processing/HDV_dataset/data_json/test.json',
                    '--output_video_root', '/home/yexin/data_processing/HDV_dataset/HDV_event',
                    '--output_json_path', '/home/yexin/data_processing/HDV_dataset/data_json/test_event.json',
                    '--max_threads', '4',
                    '--thread_threshold', '2']

    main()

'''
Example command-line invocation:

python /home/yexin/data_processing/HDV_Data_Processing/video2event/v2event.py \
    --input_video_root /home/yexin/data_processing/HDV_dataset/HDV_clip \
    --input_json_path /home/yexin/data_processing/HDV_dataset/data_json/test.json \
    --output_video_root /home/yexin/data_processing/HDV_dataset/HDV_event \
    --output_json_path /home/yexin/data_processing/HDV_dataset/data_json/test_event.json \
    --max_threads 4 \
    --thread_threshold 2
'''
