Fine-tune recipe for Zipformer (#1484)

1. Support fine-tuning Zipformer
2. Update the usage instructions; set a very large batch count
Xiaoyu Yang 2024-02-06 18:25:43 +08:00 committed by GitHub
parent a813186f64
commit 777074046d
4 changed files with 2664 additions and 6 deletions
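
Why "set a very large batch count"? A plausible reading, not stated in this commit: icefall's Zipformer recipes schedule the learning rate with the Eden scheduler (its formula is quoted in optim.py's docstring), and seeding the scheduler's batch counter with a large value pushes the schedule onto its slowly decaying tail, so fine-tuning runs at a small, nearly flat learning rate. A minimal sketch of that effect; the eden_lr helper and the default lr_batches/lr_epochs values are illustrative assumptions, not taken from this diff:

def eden_lr(base_lr, batch, epoch, lr_batches=7500.0, lr_epochs=3.5):
    # Eden schedule as documented in icefall's optim.py:
    #   lr = base_lr
    #        * ((batch**2 + lr_batches**2) / lr_batches**2) ** -0.25
    #        * ((epoch**2 + lr_epochs**2) / lr_epochs**2) ** -0.25
    batch_factor = ((batch**2 + lr_batches**2) / lr_batches**2) ** -0.25
    epoch_factor = ((epoch**2 + lr_epochs**2) / lr_epochs**2) ** -0.25
    return base_lr * batch_factor * epoch_factor

print(eden_lr(0.045, batch=0, epoch=1))       # from scratch: ~0.044
print(eden_lr(0.045, batch=100000, epoch=1))  # very large batch count: ~0.012, decaying slowly from here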


@@ -51,6 +51,14 @@ def get_parser():
         "Determines batch size dynamically.",
     )
 
+    parser.add_argument(
+        "--subset",
+        type=str,
+        default="XL",
+        choices=["XL", "L", "M", "S", "XS"],
+        help="Which subset to work with",
+    )
+
     parser.add_argument(
         "--num-splits",
         type=int,
@@ -76,7 +84,7 @@ def get_parser():
 def compute_fbank_gigaspeech_splits(args):
     num_splits = args.num_splits
 
-    output_dir = "data/fbank/XL_split"
+    output_dir = f"data/fbank/{args.subset}_split"
     output_dir = Path(output_dir)
     assert output_dir.exists(), f"{output_dir} does not exist!"
@@ -96,15 +104,15 @@ def compute_fbank_gigaspeech_splits(args):
     logging.info(f"device: {device}")
 
     for i in range(start, stop):
-        idx = f"{i + 1}".zfill(num_digits)
+        idx = f"{i}".zfill(num_digits)
         logging.info(f"Processing {idx}/{num_splits}")
 
-        cuts_path = output_dir / f"cuts_XL.{idx}.jsonl.gz"
+        cuts_path = output_dir / f"cuts_{args.subset}.{idx}.jsonl.gz"
         if cuts_path.is_file():
             logging.info(f"{cuts_path} exists - skipping")
             continue
 
-        raw_cuts_path = output_dir / f"cuts_XL_raw.{idx}.jsonl.gz"
+        raw_cuts_path = output_dir / f"cuts_{args.subset}_raw.{idx}.jsonl.gz"
         logging.info(f"Loading {raw_cuts_path}")
         cut_set = CutSet.from_file(raw_cuts_path)
@@ -113,7 +121,7 @@ def compute_fbank_gigaspeech_splits(args):
         cut_set = cut_set.compute_and_store_features_batch(
             extractor=extractor,
-            storage_path=f"{output_dir}/feats_XL_{idx}",
+            storage_path=f"{output_dir}/feats_{args.subset}_{idx}",
             num_workers=args.num_workers,
             batch_duration=args.batch_duration,
             overwrite=True,
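
Two changes run through this file: every hard-coded "XL" becomes an f-string over args.subset, and split indices are now 0-based (f"{i}" instead of f"{i + 1}"), presumably to match the numbering of the raw split files. A small sketch of the resulting filenames; num_digits = 8 is an assumed padding width, and the path templates simply mirror the f-strings in the diff above:

from pathlib import Path

subset = "S"        # any of XL / L / M / S / XS is accepted now
num_digits = 8      # assumed zero-padding width
output_dir = Path(f"data/fbank/{subset}_split")

for i in range(3):  # first few of args.num_splits chunks
    idx = f"{i}".zfill(num_digits)  # 0-based after this commit
    print(output_dir / f"cuts_{subset}.{idx}.jsonl.gz")
# data/fbank/S_split/cuts_S.00000000.jsonl.gz
# data/fbank/S_split/cuts_S.00000001.jsonl.gz
# data/fbank/S_split/cuts_S.00000002.jsonl.gz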


@@ -1,4 +1,4 @@
 # Copyright 2021 Piotr Żelasko
 # Copyright 2022 Xiaomi Corporation (Author: Mingshuang Luo)
 #
 # See ../../../../LICENSE for clarification regarding multiple authors
@@ -475,3 +475,18 @@ class LibriSpeechAsrDataModule:
         return load_manifest_lazy(
             self.args.manifest_dir / "librispeech_cuts_test-other.jsonl.gz"
         )
+
+    @lru_cache()
+    def gigaspeech_subset_small_cuts(self) -> CutSet:
+        logging.info("About to get Gigaspeech subset-S cuts")
+        return load_manifest_lazy(self.args.manifest_dir / "cuts_S.jsonl.gz")
+
+    @lru_cache()
+    def gigaspeech_dev_cuts(self) -> CutSet:
+        logging.info("About to get Gigaspeech dev cuts")
+        return load_manifest_lazy(self.args.manifest_dir / "cuts_DEV.jsonl.gz")
+
+    @lru_cache()
+    def gigaspeech_test_cuts(self) -> CutSet:
+        logging.info("About to get Gigaspeech test cuts")
+        return load_manifest_lazy(self.args.manifest_dir / "cuts_TEST.jsonl.gz")
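
The new accessors follow the same pattern as the existing librispeech_*_cuts helpers: cached with @lru_cache() and lazily loading a manifest from args.manifest_dir. A hedged sketch of how the fine-tuning script (one of the suppressed diffs below) might consume them; the add_arguments/train_dataloaders/valid_dataloaders calls are the usual icefall datamodule interface, not something shown in this diff:

import argparse

from asr_datamodule import LibriSpeechAsrDataModule

parser = argparse.ArgumentParser()
LibriSpeechAsrDataModule.add_arguments(parser)
args = parser.parse_args()

librispeech = LibriSpeechAsrDataModule(args)

# Fine-tune on the small GigaSpeech subset; validate on GigaSpeech dev.
train_dl = librispeech.train_dataloaders(librispeech.gigaspeech_subset_small_cuts())
valid_dl = librispeech.valid_dataloaders(librispeech.gigaspeech_dev_cuts())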

File diff suppressed because it is too large

File diff suppressed because it is too large