From 06bca2ffedcebae69682194c7a969969a4e86ab5 Mon Sep 17 00:00:00 2001
From: jinzr
Date: Fri, 15 Mar 2024 10:43:33 +0800
Subject: [PATCH] CommonVoice: support training on validated and invalidated sets

---
 .../local/compute_fbank_commonvoice_splits.py |  9 +++++----
 egs/commonvoice/ASR/prepare.sh                | 18 +++++++++++++++---
 .../asr_datamodule.py                         |  8 ++++++++
 .../ASR/pruned_transducer_stateless7/train.py | 12 ++++++++++++
 .../train.py                                  | 12 ++++++++++++
 egs/commonvoice/ASR/zipformer/train.py        | 12 ++++++++++++
 egs/commonvoice/ASR/zipformer/train_char.py   | 12 ++++++++++++
 7 files changed, 76 insertions(+), 7 deletions(-)

diff --git a/egs/commonvoice/ASR/local/compute_fbank_commonvoice_splits.py b/egs/commonvoice/ASR/local/compute_fbank_commonvoice_splits.py
index fd565da83..14fb9b446 100755
--- a/egs/commonvoice/ASR/local/compute_fbank_commonvoice_splits.py
+++ b/egs/commonvoice/ASR/local/compute_fbank_commonvoice_splits.py
@@ -1,5 +1,6 @@
 #!/usr/bin/env python3
-# Copyright 2023 Xiaomi Corp. (Yifan Yang)
+# Copyright 2023-2024 Xiaomi Corp. (Yifan Yang,
+#                                   Zengrui Jin,)
 #
 # See ../../../../LICENSE for clarification regarding multiple authors
 #
@@ -74,21 +75,21 @@ def get_args():
         "--num-splits",
         type=int,
         required=True,
-        help="The number of splits of the train subset",
+        help="The number of splits of the subset",
     )
 
     parser.add_argument(
         "--start",
         type=int,
         default=0,
-        help="Process pieces starting from this number (inclusive).",
+        help="Process pieces starting from this number (included).",
     )
 
     parser.add_argument(
         "--stop",
         type=int,
         default=-1,
-        help="Stop processing pieces until this number (exclusive).",
+        help="Stop processing pieces until this number (excluded).",
     )
 
     parser.add_argument(
diff --git a/egs/commonvoice/ASR/prepare.sh b/egs/commonvoice/ASR/prepare.sh
index 985351e22..4e76ef041 100755
--- a/egs/commonvoice/ASR/prepare.sh
+++ b/egs/commonvoice/ASR/prepare.sh
@@ -257,12 +257,14 @@ if [ $stage -le 7 ] && [ $stop_stage -ge 7 ]; then
     log "Also combine features for validated data"
     pieces=$(find data/${lang}/fbank/cv-${lang}_validated_split_${num_splits} -name "cv-${lang}_cuts_validated.*.jsonl.gz")
     lhotse combine $pieces data/${lang}/fbank/cv-${lang}_cuts_validated.jsonl.gz
+    touch data/${lang}/fbank/.cv-${lang}_validated.done
   fi
 
   if [ $use_invalidated = true ] && [ -f data/${lang}/fbank/.cv-${lang}_invalidated.done ]; then
     log "Also combine features for invalidated data"
-    pieces=$(find data/${lang}/fbank/cv-${lang}_inalidated_split_${num_splits} -name "cv-${lang}_cuts_invalidated.*.jsonl.gz")
+    pieces=$(find data/${lang}/fbank/cv-${lang}_invalidated_split_${num_splits} -name "cv-${lang}_cuts_invalidated.*.jsonl.gz")
     lhotse combine $pieces data/${lang}/fbank/cv-${lang}_cuts_invalidated.jsonl.gz
+    touch data/${lang}/fbank/.cv-${lang}_invalidated.done
   fi
 fi
 
@@ -289,8 +291,18 @@ if [ $stage -le 9 ] && [ $stop_stage -ge 9 ]; then
   # 1. wget -O jq https://github.com/stedolan/jq/releases/download/jq-1.6/jq-linux64
   # 2. chmod +x ./jq
   # 3. cp jq /usr/bin
-  gunzip -c data/${lang}/manifests/cv-${lang}_supervisions_train.jsonl.gz \
-    | jq '.text' | sed 's/"//g' > $lang_dir/text
+  if [ $use_validated = true ]; then
+    gunzip -c data/${lang}/manifests/cv-${lang}_supervisions_validated.jsonl.gz \
+      | jq '.text' | sed 's/"//g' >> $lang_dir/text
+  else
+    gunzip -c data/${lang}/manifests/cv-${lang}_supervisions_train.jsonl.gz \
+      | jq '.text' | sed 's/"//g' > $lang_dir/text
+  fi
+
+  if [ $use_invalidated = true ]; then
+    gunzip -c data/${lang}/manifests/cv-${lang}_supervisions_invalidated.jsonl.gz \
+      | jq '.text' | sed 's/"//g' >> $lang_dir/text
+  fi
 
   if [ $lang == "yue" ] || [ $lang == "zh-HK" ]; then
     # Get words.txt and words_no_ids.txt
diff --git a/egs/commonvoice/ASR/pruned_transducer_stateless7/asr_datamodule.py b/egs/commonvoice/ASR/pruned_transducer_stateless7/asr_datamodule.py
index bf0a3e245..c4797b945 100644
--- a/egs/commonvoice/ASR/pruned_transducer_stateless7/asr_datamodule.py
+++ b/egs/commonvoice/ASR/pruned_transducer_stateless7/asr_datamodule.py
@@ -417,6 +417,14 @@ class CommonVoiceAsrDataModule:
             / f"cv-{self.args.language}_cuts_validated.jsonl.gz"
         )
 
+    @lru_cache()
+    def invalidated_cuts(self) -> CutSet:
+        logging.info("About to get invalidated cuts")
+        return load_manifest_lazy(
+            self.args.cv_manifest_dir
+            / f"cv-{self.args.language}_cuts_invalidated.jsonl.gz"
+        )
+
     @lru_cache()
     def dev_cuts(self) -> CutSet:
         logging.info("About to get dev cuts")
diff --git a/egs/commonvoice/ASR/pruned_transducer_stateless7/train.py b/egs/commonvoice/ASR/pruned_transducer_stateless7/train.py
index a3cda30a6..dc124a3b5 100755
--- a/egs/commonvoice/ASR/pruned_transducer_stateless7/train.py
+++ b/egs/commonvoice/ASR/pruned_transducer_stateless7/train.py
@@ -258,6 +258,15 @@ def get_parser():
         """,
     )
 
+    parser.add_argument(
+        "--use-invalidated-set",
+        type=str2bool,
+        default=False,
+        help="""Whether to also use the invalidated set for training.
+        Enable it to get more training data at the risk of lower data quality.
+        """,
+    )
+
     parser.add_argument(
         "--base-lr",
         type=float,
@@ -1047,6 +1056,9 @@ def run(rank, world_size, args):
     else:
         train_cuts = commonvoice.validated_cuts()
 
+    if args.use_invalidated_set:
+        train_cuts += commonvoice.invalidated_cuts()
+
     def remove_short_and_long_utt(c: Cut):
         # Keep only utterances with duration between 1 second and 20 seconds
         #
diff --git a/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/train.py b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/train.py
index f1ad3a8f4..79908f204 100755
--- a/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/train.py
+++ b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/train.py
@@ -274,6 +274,15 @@ def get_parser():
         """,
     )
 
+    parser.add_argument(
+        "--use-invalidated-set",
+        type=str2bool,
+        default=False,
+        help="""Whether to also use the invalidated set for training.
+        Enable it to get more training data at the risk of lower data quality.
+        """,
+    )
+
     parser.add_argument(
         "--base-lr",
         type=float,
@@ -1064,6 +1073,9 @@ def run(rank, world_size, args):
     else:
         train_cuts = commonvoice.validated_cuts()
 
+    if args.use_invalidated_set:
+        train_cuts += commonvoice.invalidated_cuts()
+
     def remove_short_and_long_utt(c: Cut):
         # Keep only utterances with duration between 1 second and 20 seconds
         #
diff --git a/egs/commonvoice/ASR/zipformer/train.py b/egs/commonvoice/ASR/zipformer/train.py
index 5f543ff3f..160d490e0 100755
--- a/egs/commonvoice/ASR/zipformer/train.py
+++ b/egs/commonvoice/ASR/zipformer/train.py
@@ -337,6 +337,15 @@ def get_parser():
         """,
     )
 
+    parser.add_argument(
+        "--use-invalidated-set",
+        type=str2bool,
+        default=False,
+        help="""Whether to also use the invalidated set for training.
+        Enable it to get more training data at the risk of lower data quality.
+        """,
+    )
+
     parser.add_argument(
         "--base-lr",
         type=float,
@@ -1191,6 +1200,9 @@ def run(rank, world_size, args):
     else:
         train_cuts = commonvoice.validated_cuts()
 
+    if args.use_invalidated_set:
+        train_cuts += commonvoice.invalidated_cuts()
+
     def remove_short_and_long_utt(c: Cut):
         # Keep only utterances with duration between 1 second and 20 seconds
         #
diff --git a/egs/commonvoice/ASR/zipformer/train_char.py b/egs/commonvoice/ASR/zipformer/train_char.py
index b69ebd085..a780bbbbc 100755
--- a/egs/commonvoice/ASR/zipformer/train_char.py
+++ b/egs/commonvoice/ASR/zipformer/train_char.py
@@ -184,6 +184,15 @@ def get_parser():
         """,
     )
 
+    parser.add_argument(
+        "--use-invalidated-set",
+        type=str2bool,
+        default=False,
+        help="""Whether to also use the invalidated set for training.
+        Enable it to get more training data at the risk of lower data quality.
+        """,
+    )
+
     parser.add_argument(
         "--base-lr",
         type=float,
@@ -904,6 +913,9 @@ def run(rank, world_size, args):
     else:
         train_cuts = commonvoice.validated_cuts()
 
+    if args.use_invalidated_set:
+        train_cuts += commonvoice.invalidated_cuts()
+
     def remove_short_and_long_utt(c: Cut):
         # Keep only utterances with duration between 1 second and 20 seconds
         #