WeNet / generate_data.py
inoryQwQ's picture
First commit
3c50954
import argparse
import os
import shutil
from tqdm import tqdm
from ort_common import WenetONNXRunner, pack_calibration_dataset
def get_args():
    """Parse the command-line options of the calibration data generator.

    Reads ``sys.argv`` through ``argparse``.

    Returns:
        argparse.Namespace: parsed options with the attributes ``input``,
        ``config``, ``vocab``, ``onnx_dir``, ``calib_data_path``, ``parts``,
        ``offline_seq_len``, ``decoder_len``, ``decoding_chunk_size``,
        ``num_decoding_left_chunks``, ``max_num`` and ``keep_existing``.
    """
    cli = argparse.ArgumentParser(
        description="Generate calibration_dataset for exported ONNX models",
    )
    add = cli.add_argument

    # Mandatory inputs: audio sources plus the model config and vocabulary.
    add("--input", "-i",
        nargs="+",
        required=True,
        help="Input wav file(s) or directory/directories")
    add("--config", required=True, help="yaml file in checkpoint path")
    add("--vocab",
        required=True,
        help="pretrained units.txt, for example pretrained/<model>/units.txt")

    # Locations of the exported models and of the generated dataset.
    add("--onnx_dir",
        default="onnx_model",
        help="directory containing exported ONNX models")
    add("--calib_data_path",
        default="calibration_dataset",
        help="output calibration dataset directory")

    add("--parts",
        nargs="+",
        choices=["all", "offline", "online", "decoder"],
        default=["all"],
        help="which model inputs to generate")

    # Plain integer options without help text, registered in a fixed order
    # so the generated --help output keeps its layout.
    int_options = (
        ("--offline_seq_len", 1024),
        ("--decoder_len", 32),
        ("--decoding_chunk_size", 16),
        ("--num_decoding_left_chunks", 5),
    )
    for flag, default_value in int_options:
        add(flag, type=int, default=default_value)

    add("--max_num",
        type=int,
        default=100,
        help="maximum number of audio files used for calibration; set <= 0 to use all")
    add("--keep_existing",
        action="store_true",
        help="append to an existing calibration directory")

    return cli.parse_args()
def expand_audio_inputs(inputs):
    """Resolve input paths into a sorted list of audio file paths.

    Directories are walked recursively and only files whose extension
    (compared case-insensitively) is one of ``.wav``, ``.flac``, ``.mp3``,
    ``.m4a`` or ``.ogg`` are collected. Any other existing path is treated as
    an explicitly requested file and is kept regardless of its extension.

    Args:
        inputs: Iterable of file or directory paths.

    Returns:
        list[str]: The collected paths, sorted lexicographically.

    Raises:
        FileNotFoundError: If an input path does not exist, or if no audio
            file was found at all.
    """
    audio_exts = {".wav", ".flac", ".mp3", ".m4a", ".ogg"}
    audio_files = []
    for path in inputs:
        if os.path.isdir(path):
            for root, _, files in os.walk(path):
                audio_files.extend(
                    os.path.join(root, filename)
                    for filename in files
                    if os.path.splitext(filename)[1].lower() in audio_exts
                )
        elif os.path.exists(path):
            audio_files.append(path)
        else:
            # Fail here with the offending path instead of letting a typo
            # surface later as an opaque audio-loading error.
            raise FileNotFoundError(f"Input path does not exist: {path}")
    if not audio_files:
        raise FileNotFoundError("No audio files found")
    return sorted(audio_files)
def normalize_parts(parts):
    """Turn the requested part names into a set.

    Args:
        parts: Collection of part names; may contain the wildcard ``"all"``.

    Returns:
        set: ``{"offline", "online", "decoder"}`` when ``"all"`` is present,
        otherwise the given names as a set.
    """
    every_part = {"offline", "online", "decoder"}
    return every_part if "all" in parts else set(parts)
def limit_audio_files(audio_files, max_num):
    """Cap the list of audio files at ``max_num`` entries.

    Args:
        audio_files: Sequence of audio file paths.
        max_num: Maximum number of entries to keep. ``None`` or a value
            ``<= 0`` disables the limit.

    Returns:
        The input sequence itself when the limit is disabled, otherwise a
        slice holding its first ``max_num`` entries.
    """
    unlimited = max_num is None or max_num <= 0
    return audio_files if unlimited else audio_files[:max_num]
def main():
    """Generate calibration data for the selected audio files and pack it.

    Steps:
        1. Parse the CLI options and resolve the list of audio files.
        2. Construct the ``WenetONNXRunner``.
        3. Reset the output directory (unless ``--keep_existing``) and make
           sure it exists.
        4. Call ``runner.save_calibration_for_audio`` per audio file and sum
           the per-part sample counts it returns.
        5. Call ``pack_calibration_dataset`` on the output directory and
           print a summary.

    Raises:
        FileNotFoundError: Propagated from ``expand_audio_inputs`` when no
            audio input can be resolved.
    """
    args = get_args()
    parts = normalize_parts(args.parts)
    audio_files = limit_audio_files(expand_audio_inputs(args.input),
                                    args.max_num)

    # Build the runner BEFORE touching the output directory: if the config,
    # vocabulary or ONNX directory is wrong this raises, and an existing
    # calibration dataset must not have been deleted by then.
    runner = WenetONNXRunner(
        args.config,
        args.vocab,
        onnx_dir=args.onnx_dir,
        offline_seq_len=args.offline_seq_len,
        decoder_len=args.decoder_len,
        decoding_chunk_size=args.decoding_chunk_size,
        num_decoding_left_chunks=args.num_decoding_left_chunks,
    )

    # NOTE(review): audio_idx restarts at 0 on every run; whether
    # --keep_existing really appends (rather than overwrites) depends on how
    # save_calibration_for_audio names its output files -- confirm.
    if os.path.exists(args.calib_data_path) and not args.keep_existing:
        shutil.rmtree(args.calib_data_path)
    os.makedirs(args.calib_data_path, exist_ok=True)

    counts = {"offline": 0, "online": 0, "decoder": 0}
    # The context manager closes the progress bar even if a file fails.
    with tqdm(audio_files,
              desc="Generating calibration data",
              unit="wav") as progress:
        for audio_idx, audio_file in enumerate(progress):
            sample_counts = runner.save_calibration_for_audio(
                audio_file, parts, args.calib_data_path, audio_idx)
            for key, value in sample_counts.items():
                counts[key] += value
            progress.set_postfix(offline=counts["offline"],
                                 online=counts["online"],
                                 decoder=counts["decoder"])

    print("Packing calibration dataset...")
    pack_calibration_dataset(args.calib_data_path)
    print(f"Generated calibration data in {args.calib_data_path}")
    print(f"offline samples: {counts['offline']}")
    print(f"online samples: {counts['online']}")
    print(f"decoder samples: {counts['decoder']}")
# Run the generator only when this file is executed as a script, not when it
# is imported as a module.
if __name__ == "__main__":
    main()