| | import os |
| | import csv |
| | import glob |
| | from tqdm import tqdm |
| | import torch |
| | import torchaudio |
| | from torchmetrics.audio import ScaleInvariantSignalDistortionRatio, SignalDistortionRatio |
| |
|
| |
|
def calculate_sdr_and_sisdr(original_audio_path, separated_audio_paths):
    """Compute SDR and SI-SDR between the summed separated audio and the original.

    Every file in ``separated_audio_paths`` is loaded and the waveforms are
    added together sample by sample. The sum is then scored against the
    original audio, with the sum as the prediction and the original as the
    target.

    Args:
        original_audio_path (str): Path to the original (reference) audio file.
        separated_audio_paths (list[str]): Paths to the separated audio files
            whose sum is compared against the original.

    Returns:
        tuple[float, float]: ``(sdr, sisdr)``.

    Raises:
        ValueError: If ``separated_audio_paths`` is empty.
    """
    # Validate before any file I/O so an empty folder fails with a clear
    # message instead of an AttributeError on ``None`` further down.
    separated_audio_paths = list(separated_audio_paths)
    if not separated_audio_paths:
        raise ValueError(
            f"no separated audio files given for {original_audio_path!r}"
        )

    original_waveform, sample_rate = torchaudio.load(original_audio_path)

    combined_waveform = None
    for path in separated_audio_paths:
        separated_waveform, separated_rate = torchaudio.load(path)

        # Samples only line up when both signals share one sample rate, so
        # bring every separated file to the reference rate first.
        if separated_rate != sample_rate:
            separated_waveform = torchaudio.functional.resample(
                separated_waveform, separated_rate, sample_rate
            )

        if combined_waveform is None:
            # The first stem is capped at the reference length.
            min_length = min(original_waveform.size(1), separated_waveform.size(1))
            combined_waveform = separated_waveform[:, :min_length]
        else:
            # Trim BOTH operands to their common length; trimming only the
            # running sum fails when a later stem is longer than the sum.
            min_length = min(combined_waveform.size(1), separated_waveform.size(1))
            combined_waveform = (
                combined_waveform[:, :min_length] + separated_waveform[:, :min_length]
            )

    # Align the reference with the (possibly shorter) summed signal.
    min_length = min(original_waveform.size(1), combined_waveform.size(1))
    original_waveform = original_waveform[:, :min_length]
    combined_waveform = combined_waveform[:, :min_length]

    # Metrics take (preds, target): the summed signal is the prediction.
    sisdr_metric = ScaleInvariantSignalDistortionRatio()
    sisdr = sisdr_metric(combined_waveform, original_waveform).item()

    sdr_metric = SignalDistortionRatio()
    sdr = sdr_metric(combined_waveform, original_waveform).item()

    return sdr, sisdr
| |
|
| |
|
if __name__ == "__main__":
    # Split name: used as the sub-directory under both data roots and as the
    # base name of the output CSV.
    dset = 'balanced_train_segments'

    src_data_root = r'/data/sound/audioset/audios_32k'
    sep_data_root = r'data_engine_infer/audioset_separation_child_label'

    csv_path = os.path.join(sep_data_root, dset + '.csv')
    # ``with`` guarantees the CSV is flushed and closed even if a file fails
    # midway; ``newline=''`` is what the csv module expects for writing.
    with open(csv_path, 'w', newline='', encoding='utf-8') as csv_file:
        writer = csv.writer(csv_file)
        writer.writerow(['video', 'sdr', 'sisdr'])

        # Sorted so the row order is reproducible between runs.
        video_dirs = sorted(glob.glob(os.path.join(sep_data_root, dset, '*')))
        for video_path in tqdm(video_dirs):
            # Only directories hold separated stems; skip stray files.
            if not os.path.isdir(video_path):
                continue

            video = os.path.basename(video_path)
            original_audio_path = os.path.join(src_data_root, dset, video + '.wav')
            separated_audio_paths = sorted(glob.glob(os.path.join(video_path, '*')))
            if not separated_audio_paths:
                # Nothing to sum for this video; report it and move on rather
                # than aborting the whole run.
                tqdm.write(f'skip {video}: no separated audio files')
                continue

            sdr, sisdr = calculate_sdr_and_sisdr(original_audio_path, separated_audio_paths)
            writer.writerow([video, f'{sdr:.3f}', f'{sisdr:.3f}'])
| |
|
| |
|
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| |
|