Yuekai Zhang 2022-07-07 14:42:58 +08:00
parent a820c86337
commit 6d2641f2b9


@@ -24,13 +24,19 @@ Usage:
 export CUDA_VISIBLE_DEVICES="0,1,2,3"
 
 ./pruned_transducer_stateless5/train.py \
-  --lang-dir data/lang_char \
   --world-size 4 \
-  --num-epochs 30 \
+  --lang-dir data/lang_char \
+  --num-epochs 40 \
   --start-epoch 1 \
   --exp-dir pruned_transducer_stateless5/exp \
-  --full-libri 1 \
-  --max-duration 300
+  --max-duration 300 \
+  --use-fp16 0 \
+  --num-encoder-layers 24 \
+  --dim-feedforward 1536 \
+  --nhead 8 \
+  --encoder-dim 384 \
+  --decoder-dim 512 \
+  --joiner-dim 512
 
 # For mix precision training:
@@ -41,7 +47,6 @@ export CUDA_VISIBLE_DEVICES="0,1,2,3"
   --start-epoch 1 \
   --use-fp16 1 \
   --exp-dir pruned_transducer_stateless5/exp \
-  --full-libri 1 \
   --max-duration 550
 
 """
@@ -84,6 +89,7 @@ from icefall.checkpoint import (
 )
 from icefall.dist import cleanup_dist, setup_dist
 from icefall.env import get_env_info
+from icefall.lexicon import Lexicon
 from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
 
 LRSchedulerType = Union[
@@ -773,7 +779,8 @@ def train_one_epoch(
             scaler.update()
             optimizer.zero_grad()
         except:  # noqa
-            display_and_save_batch(batch, params=params, graph_compiler=graph_compiler)
+            display_and_save_batch(batch, params=params,
+                                   graph_compiler=graph_compiler)
             raise
 
         if params.print_diagnostics and batch_idx == 5:
@ -872,8 +879,6 @@ def run(rank, world_size, args):
""" """
params = get_params() params = get_params()
params.update(vars(args)) params.update(vars(args))
if params.full_libri is False:
params.valid_interval = 1600
fix_random_seed(params.seed) fix_random_seed(params.seed)
if world_size > 1: if world_size > 1:
@@ -951,7 +956,6 @@ def run(rank, world_size, args):
     train_cuts = aishell2.train_cuts()
 
     def remove_short_and_long_utt(c: Cut):
         # Keep only utterances with duration between 1 second and 20 seconds
         #
@@ -976,7 +980,7 @@ def run(rank, world_size, args):
         train_cuts, sampler_state_dict=sampler_state_dict
     )
 
-    valid_cuts = aishell2.dev_cuts()
+    valid_cuts = aishell2.valid_cuts()
     valid_dl = aishell2.valid_dataloaders(valid_cuts)
 
     if not params.print_diagnostics:
@@ -1109,7 +1113,8 @@ def scan_pessimistic_batches_for_oom(
                     f"Failing criterion: {criterion} "
                     f"(={crit_values[criterion]}) ..."
                 )
-            display_and_save_batch(batch, params=params, graph_compiler=graph_compiler)
+            display_and_save_batch(batch, params=params,
+                                   graph_compiler=graph_compiler)
             raise