From 8f21e92b5f4c3298d088554ae53629a9827b9f54 Mon Sep 17 00:00:00 2001
From: PingFeng Luo
Date: Fri, 14 Jan 2022 12:12:38 +0800
Subject: [PATCH] Minor fixes

---
 .../ASR/tdnn_lstm_ctc/asr_datamodule.py |  8 +++++++-
 .../ASR/transducer_stateless/train.py   | 19 +------------------
 2 files changed, 8 insertions(+), 19 deletions(-)

diff --git a/egs/wenetspeech/ASR/tdnn_lstm_ctc/asr_datamodule.py b/egs/wenetspeech/ASR/tdnn_lstm_ctc/asr_datamodule.py
index eb581a647..ee515a22b 100644
--- a/egs/wenetspeech/ASR/tdnn_lstm_ctc/asr_datamodule.py
+++ b/egs/wenetspeech/ASR/tdnn_lstm_ctc/asr_datamodule.py
@@ -29,6 +29,7 @@ from lhotse import (
     load_manifest,
     set_caching_enabled,
 )
+from lhotse.cut import Cut
 from lhotse.dataset import (
     DynamicBucketingSampler,
     CutConcatenate,
@@ -101,7 +102,7 @@ class WenetSpeechDataModule:
         group.add_argument(
             "--num-buckets",
             type=int,
-            default=30,
+            default=300,
             help="The number of buckets for the DynamicBucketingSampler"
             "(you might want to increase it for larger datasets).",
         )
@@ -285,6 +286,11 @@ class WenetSpeechDataModule:
         )
         logging.info("About to create train dataloader")
 
+        def remove_short_and_long_utt(c: Cut):
+            # Keep only utterances with duration between 1 second and 16 seconds
+            return 1.0 <= c.duration <= 16.0
+
+        train_sampler.filter(remove_short_and_long_utt)
         train_dl = DataLoader(
             train,
             sampler=train_sampler,
diff --git a/egs/wenetspeech/ASR/transducer_stateless/train.py b/egs/wenetspeech/ASR/transducer_stateless/train.py
index 9eb6deff4..5bcc82761 100755
--- a/egs/wenetspeech/ASR/transducer_stateless/train.py
+++ b/egs/wenetspeech/ASR/transducer_stateless/train.py
@@ -33,7 +33,6 @@ from asr_datamodule import WenetSpeechDataModule
 from conformer import Conformer
 from decoder import Decoder
 from joiner import Joiner
-from lhotse.cut import Cut
 from lhotse.utils import fix_random_seed
 from model import Transducer
 from torch import Tensor
@@ -196,7 +195,7 @@ def get_params() -> AttributeDict:
             "num_encoder_layers": 12,
             "vgg_frontend": False,
             # parameters for Noam
-            "warm_step": 80000,  # For the 100h subset, use 8k
+            "warm_step": 1600000,  # For the 100h subset, use 8k
             "env_info": get_env_info(),
         }
     )
@@ -591,22 +590,6 @@ def run(rank, world_size, args):
     wenetspeech = WenetSpeechDataModule(args)
     train_cuts = wenetspeech.train_cuts()
 
-    def remove_short_and_long_utt(c: Cut):
-        # Keep only utterances with duration between 1 second and 20 seconds
-        return 1.0 <= c.duration <= 20.0
-
-    num_in_total = len(train_cuts)
-
-    train_cuts = train_cuts.filter(remove_short_and_long_utt)
-
-    num_left = len(train_cuts)
-    num_removed = num_in_total - num_left
-    removed_percent = num_removed / num_in_total * 100
-
-    logging.info(f"Before removing short and long utterances: {num_in_total}")
-    logging.info(f"After removing short and long utterances: {num_left}")
-    logging.info(f"Removed {num_removed} utterances ({removed_percent:.5f}%)")
-
     train_dl = wenetspeech.train_dataloaders(train_cuts)
     valid_dl = wenetspeech.valid_dataloaders(wenetspeech.valid_cuts())
 