mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-18 21:44:18 +00:00
fix lint
This commit is contained in:
parent
a820c86337
commit
6d2641f2b9
@ -24,13 +24,19 @@ Usage:
|
|||||||
export CUDA_VISIBLE_DEVICES="0,1,2,3"
|
export CUDA_VISIBLE_DEVICES="0,1,2,3"
|
||||||
|
|
||||||
./pruned_transducer_stateless5/train.py \
|
./pruned_transducer_stateless5/train.py \
|
||||||
--lang-dir data/lang_char \
|
|
||||||
--world-size 4 \
|
--world-size 4 \
|
||||||
--num-epochs 30 \
|
--lang-dir data/lang_char \
|
||||||
|
--num-epochs 40 \
|
||||||
--start-epoch 1 \
|
--start-epoch 1 \
|
||||||
--exp-dir pruned_transducer_stateless5/exp \
|
--exp-dir pruned_transducer_stateless5/exp \
|
||||||
--full-libri 1 \
|
--max-duration 300 \
|
||||||
--max-duration 300
|
--use-fp16 0 \
|
||||||
|
--num-encoder-layers 24 \
|
||||||
|
--dim-feedforward 1536 \
|
||||||
|
--nhead 8 \
|
||||||
|
--encoder-dim 384 \
|
||||||
|
--decoder-dim 512 \
|
||||||
|
--joiner-dim 512
|
||||||
|
|
||||||
# For mix precision training:
|
# For mix precision training:
|
||||||
|
|
||||||
@ -41,7 +47,6 @@ export CUDA_VISIBLE_DEVICES="0,1,2,3"
|
|||||||
--start-epoch 1 \
|
--start-epoch 1 \
|
||||||
--use-fp16 1 \
|
--use-fp16 1 \
|
||||||
--exp-dir pruned_transducer_stateless5/exp \
|
--exp-dir pruned_transducer_stateless5/exp \
|
||||||
--full-libri 1 \
|
|
||||||
--max-duration 550
|
--max-duration 550
|
||||||
|
|
||||||
"""
|
"""
|
||||||
@ -84,6 +89,7 @@ from icefall.checkpoint import (
|
|||||||
)
|
)
|
||||||
from icefall.dist import cleanup_dist, setup_dist
|
from icefall.dist import cleanup_dist, setup_dist
|
||||||
from icefall.env import get_env_info
|
from icefall.env import get_env_info
|
||||||
|
from icefall.lexicon import Lexicon
|
||||||
from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
|
from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
|
||||||
|
|
||||||
LRSchedulerType = Union[
|
LRSchedulerType = Union[
|
||||||
@ -773,7 +779,8 @@ def train_one_epoch(
|
|||||||
scaler.update()
|
scaler.update()
|
||||||
optimizer.zero_grad()
|
optimizer.zero_grad()
|
||||||
except: # noqa
|
except: # noqa
|
||||||
display_and_save_batch(batch, params=params, graph_compiler=graph_compiler)
|
display_and_save_batch(batch, params=params,
|
||||||
|
graph_compiler=graph_compiler)
|
||||||
raise
|
raise
|
||||||
|
|
||||||
if params.print_diagnostics and batch_idx == 5:
|
if params.print_diagnostics and batch_idx == 5:
|
||||||
@ -872,8 +879,6 @@ def run(rank, world_size, args):
|
|||||||
"""
|
"""
|
||||||
params = get_params()
|
params = get_params()
|
||||||
params.update(vars(args))
|
params.update(vars(args))
|
||||||
if params.full_libri is False:
|
|
||||||
params.valid_interval = 1600
|
|
||||||
|
|
||||||
fix_random_seed(params.seed)
|
fix_random_seed(params.seed)
|
||||||
if world_size > 1:
|
if world_size > 1:
|
||||||
@ -951,7 +956,6 @@ def run(rank, world_size, args):
|
|||||||
|
|
||||||
train_cuts = aishell2.train_cuts()
|
train_cuts = aishell2.train_cuts()
|
||||||
|
|
||||||
|
|
||||||
def remove_short_and_long_utt(c: Cut):
|
def remove_short_and_long_utt(c: Cut):
|
||||||
# Keep only utterances with duration between 1 second and 20 seconds
|
# Keep only utterances with duration between 1 second and 20 seconds
|
||||||
#
|
#
|
||||||
@ -976,7 +980,7 @@ def run(rank, world_size, args):
|
|||||||
train_cuts, sampler_state_dict=sampler_state_dict
|
train_cuts, sampler_state_dict=sampler_state_dict
|
||||||
)
|
)
|
||||||
|
|
||||||
valid_cuts = aishell2.dev_cuts()
|
valid_cuts = aishell2.valid_cuts()
|
||||||
valid_dl = aishell2.valid_dataloaders(valid_cuts)
|
valid_dl = aishell2.valid_dataloaders(valid_cuts)
|
||||||
|
|
||||||
if not params.print_diagnostics:
|
if not params.print_diagnostics:
|
||||||
@ -1109,7 +1113,8 @@ def scan_pessimistic_batches_for_oom(
|
|||||||
f"Failing criterion: {criterion} "
|
f"Failing criterion: {criterion} "
|
||||||
f"(={crit_values[criterion]}) ..."
|
f"(={crit_values[criterion]}) ..."
|
||||||
)
|
)
|
||||||
display_and_save_batch(batch, params=params, graph_compiler=graph_compiler)
|
display_and_save_batch(batch, params=params,
|
||||||
|
graph_compiler=graph_compiler)
|
||||||
raise
|
raise
|
||||||
|
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user