"""Load the comparison dataset and turn its rows into ready-to-use audio tasks."""

import os
import sys

# This flag has to be in the environment BEFORE `datasets` (and, through it,
# `huggingface_hub`) is imported: the hub library appears to read it once at
# import time, so assigning it after the import would not silence the warning.
os.environ["HF_HUB_DISABLE_SYMLINKS_WARNING"] = "true"

import numpy as np  # noqa: E402
from datasets import load_dataset  # noqa: E402

import config  # noqa: E402

# Load the dataset once at import time; `load_tasks` iterates over it.
dataset = load_dataset(config.PROCESS_REPO_ID, split="train")


def process_audio(audio_obj):
    """Extract sample data and sample rate from a decoded-audio object.

    The object must expose ``get_all_samples()`` returning something with
    ``.data`` and ``.sample_rate`` attributes.

    Args:
        audio_obj: The audio object taken from a dataset row.

    Returns:
        ``(audio_data, sample_rate)`` where ``audio_data`` is a NumPy array
        and ``sample_rate`` is an ``int``, or ``(None, None)`` when the object
        lacks ``get_all_samples`` or any step of the extraction raises.
    """
    try:
        if not hasattr(audio_obj, 'get_all_samples'):
            print("音频对象缺少 get_all_samples 方法")
            return None, None

        samples = audio_obj.get_all_samples()

        audio_data = samples.data
        if not isinstance(audio_data, np.ndarray):
            # Non-NumPy containers are converted to float32; arrays that are
            # already NumPy keep their original dtype.
            audio_data = np.array(audio_data, dtype=np.float32)

        sample_rate = samples.sample_rate
        if not isinstance(sample_rate, int):
            sample_rate = int(sample_rate)

        if audio_data.ndim > 1:
            # Collapse the leading axis by averaging.
            # NOTE(review): assumes a (channels, samples) layout so this is a
            # mono down-mix -- confirm against the decoder in use.
            audio_data = audio_data.mean(axis=0)

        return audio_data, sample_rate
    except Exception as e:
        # Best effort: a single bad clip must not abort the whole load.
        print(f"处理音频失败: {e}")
        return None, None


def _is_valid_audio(data, rate):
    """Return True when *data* is a NumPy array and *rate* is an int."""
    return isinstance(data, np.ndarray) and isinstance(rate, int)


def load_tasks():
    """Pre-process every dataset row into a task dictionary.

    Rows for which either audio clip cannot be processed are skipped with a
    printed notice.

    Returns:
        A non-empty list of dicts with the keys ``"instruction"``, ``"text"``,
        ``"audioA"`` and ``"audioB"``; each audio value is an
        ``(audio_data, sample_rate)`` tuple.

    Raises:
        SystemExit: If no row produced a usable task.
    """
    print("处理数据集...")
    tasks = []

    for i, row in enumerate(dataset):
        audioA_data, audioA_rate = process_audio(row[config.FIELD_AUDIO_A])
        audioB_data, audioB_rate = process_audio(row[config.FIELD_AUDIO_B])

        if not (_is_valid_audio(audioA_data, audioA_rate)
                and _is_valid_audio(audioB_data, audioB_rate)):
            print(f"跳过任务 {i}:无效的音频数据")
            continue

        tasks.append({
            "instruction": config.INSTRUCTION,
            "text": row[config.FIELD_TEXT],
            "audioA": (audioA_data, audioA_rate),
            "audioB": (audioB_data, audioB_rate),
        })

    print(f"成功处理 {len(tasks)} 个任务")

    if not tasks:
        print("没有可用任务!")
        # `sys.exit` instead of the bare `exit()` helper: the latter is
        # injected by the `site` module for interactive sessions and is not
        # guaranteed to exist (e.g. `python -S`, frozen/embedded builds).
        # Called without an argument to keep the original exit status.
        sys.exit()

    return tasks