diff --git a/egs/aishell2/ASR/pruned_transducer_stateless5/train.py b/egs/aishell2/ASR/pruned_transducer_stateless5/train.py index b1052813c..0d9bd0a6a 100755 --- a/egs/aishell2/ASR/pruned_transducer_stateless5/train.py +++ b/egs/aishell2/ASR/pruned_transducer_stateless5/train.py @@ -24,13 +24,19 @@ Usage: export CUDA_VISIBLE_DEVICES="0,1,2,3" ./pruned_transducer_stateless5/train.py \ - --lang-dir data/lang_char \ --world-size 4 \ - --num-epochs 30 \ + --lang-dir data/lang_char \ + --num-epochs 40 \ --start-epoch 1 \ --exp-dir pruned_transducer_stateless5/exp \ - --full-libri 1 \ - --max-duration 300 + --max-duration 300 \ + --use-fp16 0 \ + --num-encoder-layers 24 \ + --dim-feedforward 1536 \ + --nhead 8 \ + --encoder-dim 384 \ + --decoder-dim 512 \ + --joiner-dim 512 # For mix precision training: @@ -41,7 +47,6 @@ export CUDA_VISIBLE_DEVICES="0,1,2,3" --start-epoch 1 \ --use-fp16 1 \ --exp-dir pruned_transducer_stateless5/exp \ - --full-libri 1 \ --max-duration 550 """ @@ -84,6 +89,7 @@ from icefall.checkpoint import ( ) from icefall.dist import cleanup_dist, setup_dist from icefall.env import get_env_info +from icefall.lexicon import Lexicon from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool LRSchedulerType = Union[ @@ -773,7 +779,8 @@ def train_one_epoch( scaler.update() optimizer.zero_grad() except: # noqa - display_and_save_batch(batch, params=params, graph_compiler=graph_compiler) + display_and_save_batch(batch, params=params, + graph_compiler=graph_compiler) raise if params.print_diagnostics and batch_idx == 5: @@ -872,8 +879,6 @@ def run(rank, world_size, args): """ params = get_params() params.update(vars(args)) - if params.full_libri is False: - params.valid_interval = 1600 fix_random_seed(params.seed) if world_size > 1: @@ -951,7 +956,6 @@ def run(rank, world_size, args): train_cuts = aishell2.train_cuts() - def remove_short_and_long_utt(c: Cut): # Keep only utterances with duration between 1 second and 20 seconds # @@ -976,7 +980,7 @@ def run(rank, world_size, args): train_cuts, sampler_state_dict=sampler_state_dict ) - valid_cuts = aishell2.dev_cuts() + valid_cuts = aishell2.valid_cuts() valid_dl = aishell2.valid_dataloaders(valid_cuts) if not params.print_diagnostics: @@ -1109,7 +1113,8 @@ def scan_pessimistic_batches_for_oom( f"Failing criterion: {criterion} " f"(={crit_values[criterion]}) ..." ) - display_and_save_batch(batch, params=params, graph_compiler=graph_compiler) + display_and_save_batch(batch, params=params, + graph_compiler=graph_compiler) raise