From 2c8d65904c7dee6a74adc4bb84e0007abc5cf92a Mon Sep 17 00:00:00 2001 From: JinZr <60612200+JinZr@users.noreply.github.com> Date: Fri, 21 Jul 2023 01:54:35 +0800 Subject: [PATCH] updated --- .../ASR/zipformer/multi_dataset.py | 22 +++++++++---------- egs/multi_zh-hans/ASR/zipformer/train.py | 20 ++++++++++------- 2 files changed, 23 insertions(+), 19 deletions(-) diff --git a/egs/multi_zh-hans/ASR/zipformer/multi_dataset.py b/egs/multi_zh-hans/ASR/zipformer/multi_dataset.py index d957c6cc0..3fabac6d3 100644 --- a/egs/multi_zh-hans/ASR/zipformer/multi_dataset.py +++ b/egs/multi_zh-hans/ASR/zipformer/multi_dataset.py @@ -216,14 +216,14 @@ class MultiDataset: self.fbank_dir / "wenetspeech" / "cuts_DEV.jsonl.gz" ) - - return [ - aidatatang_dev_cuts, - aishell_dev_cuts, - aishell2_dev_cuts, - alimeeting_dev_cuts, - magicdata_dev_cuts, - kespeech_dev_phase1_cuts, - kespeech_dev_phase2_cuts, - wenetspeech_dev_cuts, - ] + return wenetspeech_dev_cuts + # return [ + # aidatatang_dev_cuts, + # aishell_dev_cuts, + # aishell2_dev_cuts, + # alimeeting_dev_cuts, + # magicdata_dev_cuts, + # kespeech_dev_phase1_cuts, + # kespeech_dev_phase2_cuts, + # wenetspeech_dev_cuts, + # ] diff --git a/egs/multi_zh-hans/ASR/zipformer/train.py b/egs/multi_zh-hans/ASR/zipformer/train.py index bc3e9c1ba..1b7b4cc83 100755 --- a/egs/multi_zh-hans/ASR/zipformer/train.py +++ b/egs/multi_zh-hans/ASR/zipformer/train.py @@ -72,6 +72,7 @@ from lhotse.cut import Cut from lhotse.dataset.sampling.base import CutSampler from lhotse.utils import fix_random_seed from model import AsrModel +from multi_dataset import MultiDataset from optim import Eden, ScaledAdam from scaling import ScheduledFloat from subsampling import Conv2dSubsampling @@ -324,7 +325,7 @@ def get_parser(): parser.add_argument( "--bpe-model", type=str, - default="data/lang_bpe_500/bpe.model", + default="data/lang_bpe_2000/bpe.model", help="Path to the BPE model", ) @@ -1174,11 +1175,13 @@ def run(rank, world_size, args): register_inf_check_hooks(model) librispeech = LibriSpeechAsrDataModule(args) + multi_dataset = MultiDataset(args.manifest_dir) - train_cuts = librispeech.train_clean_100_cuts() - if params.full_libri: - train_cuts += librispeech.train_clean_360_cuts() - train_cuts += librispeech.train_other_500_cuts() + train_cuts = multi_dataset.train_cuts() + # train_cuts = librispeech.train_clean_100_cuts() + # if params.full_libri: + # train_cuts += librispeech.train_clean_360_cuts() + # train_cuts += librispeech.train_other_500_cuts() def remove_short_and_long_utt(c: Cut): # Keep only utterances with duration between 1 second and 20 seconds @@ -1189,7 +1192,7 @@ def run(rank, world_size, args): # You should use ../local/display_manifest_statistics.py to get # an utterance duration distribution for your dataset to select # the threshold - if c.duration < 1.0 or c.duration > 20.0: + if c.duration < 1.0 or c.duration > 600.0: # logging.warning( # f"Exclude cut with ID {c.id} from training. Duration: {c.duration}" # ) @@ -1230,8 +1233,9 @@ def run(rank, world_size, args): train_cuts, sampler_state_dict=sampler_state_dict ) - valid_cuts = librispeech.dev_clean_cuts() - valid_cuts += librispeech.dev_other_cuts() + # valid_cuts = librispeech.dev_clean_cuts() + # valid_cuts += librispeech.dev_other_cuts() + valid_cuts = multi_dataset.dev_cuts() valid_dl = librispeech.valid_dataloaders(valid_cuts) if not params.print_diagnostics: