diff --git a/egs/mls/ASR/local/compute_fbank_mls_splits.py b/egs/mls/ASR/local/compute_fbank_mls_splits.py
new file mode 100755
index 000000000..ed9e7492c
--- /dev/null
+++ b/egs/mls/ASR/local/compute_fbank_mls_splits.py
@@ -0,0 +1,172 @@
+#!/usr/bin/env python3
+# Copyright 2021 Johns Hopkins University (Piotr Żelasko)
+# Copyright 2021 Xiaomi Corp. (Fangjun Kuang)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import argparse
+import logging
+from datetime import datetime
+from pathlib import Path
+
+import torch
+from lhotse import CutSet, KaldifeatFbank, KaldifeatFbankConfig
+
+# Torch's multithreaded behavior needs to be disabled, or
+# it wastes a lot of CPU and slows things down.
+# Do this outside of main() in case it needs to take effect
+# even when we are not invoking the main (e.g. when spawning subprocesses).
+torch.set_num_threads(1)
+torch.set_num_interop_threads(1)
+
+
+def get_parser():
+    parser = argparse.ArgumentParser(
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter
+    )
+
+    parser.add_argument(
+        "--num-workers",
+        type=int,
+        default=20,
+        help="Number of dataloading workers used for reading the audio.",
+    )
+    parser.add_argument(
+        "--batch-duration",
+        type=float,
+        default=600.0,
+        help="The maximum number of audio seconds in a batch. "
+        "Determines batch size dynamically.",
+    )
+
+    parser.add_argument(
+        "--language",
+        type=str,
+        default="english",
+    )
+
+    parser.add_argument(
+        "--num-splits",
+        type=int,
+        required=True,
+        help="The number of splits of the MLS subset",
+    )
+
+    parser.add_argument(
+        "--start",
+        type=int,
+        default=0,
+        help="Process pieces starting from this number (inclusive).",
+    )
+
+    parser.add_argument(
+        "--stop",
+        type=int,
+        default=-1,
+        help="Stop processing pieces at this number (exclusive).",
+    )
+
+    parser.add_argument(
+        "--fbank-dir",
+        type=str,
+        default="data/fbank_mls",
+    )
+    return parser
+
+
+def compute_fbank_mls_splits(args):
+    num_splits = args.num_splits
+    output_dir = f"{args.fbank_dir}/{args.language}_split"
+    output_dir = Path(output_dir)
+    assert output_dir.exists(), f"{output_dir} does not exist!"
+
+    num_digits = 8  # num_digits is fixed by lhotse split-lazy
+
+    start = args.start
+    stop = args.stop
+    if stop < start:
+        stop = num_splits
+
+    stop = min(stop, num_splits)
+
+    device = torch.device("cpu")
+    if torch.cuda.is_available():
+        device = torch.device("cuda", 0)
+    extractor = KaldifeatFbank(KaldifeatFbankConfig(device=device))
+    logging.info(f"device: {device}")
+
+    for i in range(start, stop):
+        idx = f"{i}".zfill(num_digits)
+        logging.info(f"Processing {idx}/{num_splits}")
+
+        cuts_path = output_dir / f"cuts_{args.language}.{idx}.jsonl.gz"
+        if cuts_path.is_file():
+            logging.info(f"{cuts_path} exists - skipping")
+            continue
+
+        raw_cuts_path = output_dir / f"cuts_{args.language}_raw.{idx}.jsonl.gz"
+
+        logging.info(f"Loading {raw_cuts_path}")
+        cut_set = CutSet.from_file(raw_cuts_path)
+
+        logging.info("Computing features")
+
+        cut_set = cut_set.compute_and_store_features_batch(
+            extractor=extractor,
+            storage_path=f"{output_dir}/feats_{args.language}_{idx}",
+            num_workers=args.num_workers,
+            batch_duration=args.batch_duration,
+            overwrite=True,
+        )
+
+        logging.info("About to trim cuts to supervisions.")
+        cut_set = cut_set.trim_to_supervisions(
+            keep_overlapping=False, min_duration=None
+        )
+
+        logging.info(f"Saving to {cuts_path}")
+        cut_set.to_file(cuts_path)
+        logging.info(f"Saved to {cuts_path}")
+
+
+def main():
+    now = datetime.now()
+    date_time = now.strftime("%Y-%m-%d-%H-%M-%S")
+
+    log_filename = "log-compute_fbank_mls_splits"
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    log_filename = f"{log_filename}-{date_time}"
+
+    logging.basicConfig(
+        filename=log_filename,
+        format=formatter,
+        level=logging.INFO,
+        filemode="w",
+    )
+
+    console = logging.StreamHandler()
+    console.setLevel(logging.INFO)
+    console.setFormatter(logging.Formatter(formatter))
+    logging.getLogger("").addHandler(console)
+
+    parser = get_parser()
+    args = parser.parse_args()
+    logging.info(vars(args))
+
+    compute_fbank_mls_splits(args)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/egs/mls/ASR/local/validate_manifest.py b/egs/mls/ASR/local/validate_manifest.py
new file mode 120000
index 000000000..0a9725e87
--- /dev/null
+++ b/egs/mls/ASR/local/validate_manifest.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/local/validate_manifest.py
\ No newline at end of file
diff --git a/egs/mls/ASR/zipformer/beam_search.py b/egs/mls/ASR/zipformer/beam_search.py
new file mode 120000
index 000000000..e24eca39f
--- /dev/null
+++ b/egs/mls/ASR/zipformer/beam_search.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/pruned_transducer_stateless2/beam_search.py
\ No newline at end of file
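
Note: the cuts_<language>_raw.XXXXXXXX.jsonl.gz pieces consumed by compute_fbank_mls_splits.py are produced in an earlier preparation step with lhotse's CutSet.split_lazy, which (per the num_digits comment in the script) zero-pads the chunk index to 8 digits. A minimal sketch of that step follows; the input manifest filename and the chunk size are assumptions chosen for illustration, not values taken from this recipe.

# Illustrative sketch: prepare the raw split manifests that
# compute_fbank_mls_splits.py reads from data/fbank_mls/<language>_split.
# The input manifest path and chunk size below are assumptions.
from pathlib import Path

from lhotse import CutSet

language = "english"
fbank_dir = Path("data/fbank_mls")
output_dir = fbank_dir / f"{language}_split"
output_dir.mkdir(parents=True, exist_ok=True)

# Assumed location of the full (unsplit) raw cut manifest for this language.
cuts = CutSet.from_file(fbank_dir / f"cuts_{language}_raw.jsonl.gz")

# split_lazy streams the manifest and writes chunks named
# "<prefix>.<zero-padded index>.jsonl.gz" into output_dir, matching the
# cuts_<language>_raw.XXXXXXXX.jsonl.gz paths expected by the script above.
cuts.split_lazy(
    output_dir,
    chunk_size=50000,  # assumed number of cuts per chunk
    prefix=f"cuts_{language}_raw",
)

--num-splits is then set to the number of chunk files written into <language>_split.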