This commit is contained in:
Yuekai Zhang 2022-07-07 14:42:58 +08:00
parent a820c86337
commit 6d2641f2b9

View File

@ -24,13 +24,19 @@ Usage:
export CUDA_VISIBLE_DEVICES="0,1,2,3"
./pruned_transducer_stateless5/train.py \
--lang-dir data/lang_char \
--world-size 4 \
--num-epochs 30 \
--lang-dir data/lang_char \
--num-epochs 40 \
--start-epoch 1 \
--exp-dir pruned_transducer_stateless5/exp \
--full-libri 1 \
--max-duration 300
--max-duration 300 \
--use-fp16 0 \
--num-encoder-layers 24 \
--dim-feedforward 1536 \
--nhead 8 \
--encoder-dim 384 \
--decoder-dim 512 \
--joiner-dim 512
# For mix precision training:
@ -41,7 +47,6 @@ export CUDA_VISIBLE_DEVICES="0,1,2,3"
--start-epoch 1 \
--use-fp16 1 \
--exp-dir pruned_transducer_stateless5/exp \
--full-libri 1 \
--max-duration 550
"""
@ -84,6 +89,7 @@ from icefall.checkpoint import (
)
from icefall.dist import cleanup_dist, setup_dist
from icefall.env import get_env_info
from icefall.lexicon import Lexicon
from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
LRSchedulerType = Union[
@ -773,7 +779,8 @@ def train_one_epoch(
scaler.update()
optimizer.zero_grad()
except: # noqa
display_and_save_batch(batch, params=params, graph_compiler=graph_compiler)
display_and_save_batch(batch, params=params,
graph_compiler=graph_compiler)
raise
if params.print_diagnostics and batch_idx == 5:
@ -872,8 +879,6 @@ def run(rank, world_size, args):
"""
params = get_params()
params.update(vars(args))
if params.full_libri is False:
params.valid_interval = 1600
fix_random_seed(params.seed)
if world_size > 1:
@ -951,7 +956,6 @@ def run(rank, world_size, args):
train_cuts = aishell2.train_cuts()
def remove_short_and_long_utt(c: Cut):
# Keep only utterances with duration between 1 second and 20 seconds
#
@ -976,7 +980,7 @@ def run(rank, world_size, args):
train_cuts, sampler_state_dict=sampler_state_dict
)
valid_cuts = aishell2.dev_cuts()
valid_cuts = aishell2.valid_cuts()
valid_dl = aishell2.valid_dataloaders(valid_cuts)
if not params.print_diagnostics:
@ -1109,7 +1113,8 @@ def scan_pessimistic_batches_for_oom(
f"Failing criterion: {criterion} "
f"(={crit_values[criterion]}) ..."
)
display_and_save_batch(batch, params=params, graph_compiler=graph_compiler)
display_and_save_batch(batch, params=params,
graph_compiler=graph_compiler)
raise