diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless5/train.py b/egs/wenetspeech/ASR/pruned_transducer_stateless5/train.py index 0dade163b..85101a697 100755 --- a/egs/wenetspeech/ASR/pruned_transducer_stateless5/train.py +++ b/egs/wenetspeech/ASR/pruned_transducer_stateless5/train.py @@ -830,7 +830,9 @@ def train_one_epoch( scaler.update() optimizer.zero_grad() except: # noqa - display_and_save_batch(batch, params=params, graph_compiler=graph_compiler) + display_and_save_batch( + batch, params=params, graph_compiler=graph_compiler + ) raise if params.print_diagnostics and batch_idx == 5: @@ -929,7 +931,7 @@ def run(rank, world_size, args): """ params = get_params() params.update(vars(args)) - + fix_random_seed(params.seed) if world_size > 1: setup_dist(rank, world_size, params.master_port) @@ -1169,7 +1171,9 @@ def scan_pessimistic_batches_for_oom( f"Failing criterion: {criterion} " f"(={crit_values[criterion]}) ..." ) - display_and_save_batch(batch, params=params, graph_compiler=graph_compiler) + display_and_save_batch( + batch, params=params, graph_compiler=graph_compiler + ) raise