DataLabelingApp / data_processing.py
sunnyzjx's picture
Update data_processing.py
a7eb676 verified
import numpy as np
from datasets import load_dataset
import os
import config
# Ask the Hugging Face Hub client not to emit its symlink warning.
# NOTE(review): this is assigned after `datasets` has already been imported;
# confirm the variable is still honored when set this late.
os.environ["HF_HUB_DISABLE_SYMLINKS_WARNING"] = "true"
# Load the "train" split of the configured repo once, at import time.
# load_tasks() iterates over this module-level dataset.
dataset = load_dataset(config.PROCESS_REPO_ID, split="train")
def process_audio(audio_obj):
    """Extract sample data and sample rate from a decoded-audio object.

    The object must expose ``get_all_samples()``, whose result provides
    ``.data`` and ``.sample_rate``.

    Args:
        audio_obj: Any object; only those with a ``get_all_samples``
            attribute are processed.

    Returns:
        ``(samples, sample_rate)`` where ``samples`` is a ``numpy.ndarray``
        and ``sample_rate`` is an ``int``, or ``(None, None)`` when the
        object lacks ``get_all_samples`` or any step raises. Failures are
        reported via ``print`` rather than propagated.
    """
    try:
        # Guard clause: objects without the decoding method are unusable.
        if not hasattr(audio_obj, 'get_all_samples'):
            print("音频对象缺少 get_all_samples 方法")
            return None, None

        decoded = audio_obj.get_all_samples()

        # Non-ndarray payloads (e.g. tensors, lists) are copied into a
        # float32 array; an existing ndarray is kept as-is, dtype included.
        waveform = decoded.data
        if not isinstance(waveform, np.ndarray):
            waveform = np.array(waveform, dtype=np.float32)

        # Coerce to a plain Python int (covers e.g. numpy integers/floats).
        rate = decoded.sample_rate
        if not isinstance(rate, int):
            rate = int(rate)

        # Collapse multi-dimensional data by averaging over the first axis
        # (presumably the channel axis, i.e. a mono down-mix — confirm
        # against the decoder's output layout).
        if waveform.ndim > 1:
            waveform = waveform.mean(axis=0)

        return waveform, rate
    except Exception as err:
        # Best-effort: report the problem and let the caller skip this clip.
        print(f"处理音频失败: {err}")
        return None, None
def load_tasks():
    """Build the list of labeling tasks from the module-level ``dataset``.

    For every row, both audio fields (``config.FIELD_AUDIO_A`` and
    ``config.FIELD_AUDIO_B``) are decoded with ``process_audio``. Rows
    where either clip does not come back as ``(numpy.ndarray, int)`` are
    skipped with a printed message.

    Returns:
        list[dict]: One dict per usable row with the keys ``"instruction"``
        (``config.INSTRUCTION``), ``"text"`` (the row's
        ``config.FIELD_TEXT`` value), and ``"audioA"`` / ``"audioB"``
        (``(samples, sample_rate)`` tuples).

    Raises:
        SystemExit: With exit status 1 when no row produced a usable task.
    """
    print("处理数据集...")
    tasks = []
    for i, row in enumerate(dataset):
        # Always decode both clips, A first, so any per-clip failure
        # messages appear in a stable order.
        data_a, rate_a = process_audio(row[config.FIELD_AUDIO_A])
        data_b, rate_b = process_audio(row[config.FIELD_AUDIO_B])

        # isinstance() already rejects the (None, None) failure result,
        # so no separate None checks are needed.
        clips_ok = (
            isinstance(data_a, np.ndarray) and isinstance(rate_a, int)
            and isinstance(data_b, np.ndarray) and isinstance(rate_b, int)
        )
        if not clips_ok:
            print(f"跳过任务 {i}:无效的音频数据")
            continue

        tasks.append({
            "instruction": config.INSTRUCTION,
            "text": row[config.FIELD_TEXT],
            "audioA": (data_a, rate_a),
            "audioB": (data_b, rate_b),
        })

    print(f"成功处理 {len(tasks)} 个任务")
    if not tasks:
        print("没有可用任务!")
        # Raise SystemExit directly instead of calling the exit() builtin:
        # exit() is injected by the site module for interactive sessions and
        # is absent under `python -S` or in frozen/embedded interpreters.
        # A non-zero status marks this as a failure for the launcher.
        raise SystemExit(1)
    return tasks