From ce2d8171147a00196ac7a3955f9292a9f7e6aa2c Mon Sep 17 00:00:00 2001 From: yaozengwei Date: Sun, 17 Jul 2022 20:36:20 +0800 Subject: [PATCH] pruned2 -> pruned4 --- .../ASR/lstm_transducer_stateless/decode.py | 161 +++++++++----- .../streaming_decode.py | 141 +++++++++---- .../ASR/lstm_transducer_stateless/train.py | 196 ++++++++++-------- 3 files changed, 319 insertions(+), 179 deletions(-) diff --git a/egs/librispeech/ASR/lstm_transducer_stateless/decode.py b/egs/librispeech/ASR/lstm_transducer_stateless/decode.py index 34e8e8fb9..f7e3677da 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless/decode.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless/decode.py @@ -1,6 +1,7 @@ #!/usr/bin/env python3 # -# Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang) +# Copyright 2021-2022 Xiaomi Corporation (Author: Fangjun Kuang, +# Zengwei Yao) # # See ../../../../LICENSE for clarification regarding multiple authors # @@ -19,7 +20,7 @@ Usage: (1) greedy search ./lstm_transducer_stateless/decode.py \ - --epoch 28 \ + --epoch 30 \ --avg 15 \ --exp-dir ./lstm_transducer_stateless/exp \ --max-duration 600 \ @@ -27,7 +28,7 @@ Usage: (2) beam search (not recommended) ./lstm_transducer_stateless/decode.py \ - --epoch 28 \ + --epoch 30 \ --avg 15 \ --exp-dir ./lstm_transducer_stateless/exp \ --max-duration 600 \ @@ -36,7 +37,7 @@ Usage: (3) modified beam search ./lstm_transducer_stateless/decode.py \ - --epoch 28 \ + --epoch 30 \ --avg 15 \ --exp-dir ./lstm_transducer_stateless/exp \ --max-duration 600 \ @@ -45,7 +46,7 @@ Usage: (4) fast beam search (one best) ./lstm_transducer_stateless/decode.py \ - --epoch 28 \ + --epoch 30 \ --avg 15 \ --exp-dir ./lstm_transducer_stateless/exp \ --max-duration 600 \ @@ -56,9 +57,9 @@ Usage: (5) fast beam search (nbest) ./lstm_transducer_stateless/decode.py \ - --epoch 28 \ + --epoch 30 \ --avg 15 \ - --exp-dir ./lstm_transducer_stateless/exp \ + --exp-dir ./pruned_transducer_stateless3/exp \ --max-duration 600 \ --decoding-method fast_beam_search_nbest \ --beam 20.0 \ @@ -69,7 +70,7 @@ Usage: (6) fast beam search (nbest oracle WER) ./lstm_transducer_stateless/decode.py \ - --epoch 28 \ + --epoch 30 \ --avg 15 \ --exp-dir ./lstm_transducer_stateless/exp \ --max-duration 600 \ @@ -119,6 +120,7 @@ from train import get_params, get_transducer_model from icefall.checkpoint import ( average_checkpoints, + average_checkpoints_with_averaged_model, find_checkpoints, load_checkpoint, ) @@ -127,6 +129,7 @@ from icefall.utils import ( AttributeDict, setup_logger, store_transcripts, + str2bool, write_error_stats, ) @@ -141,9 +144,9 @@ def get_parser(): parser.add_argument( "--epoch", type=int, - default=28, + default=30, help="""It specifies the checkpoint to use for decoding. - Note: Epoch counts from 0. + Note: Epoch counts from 1. You can specify --avg to use more checkpoints for model averaging.""", ) @@ -166,10 +169,21 @@ def get_parser(): "'--epoch' and '--iter'", ) + parser.add_argument( + "--use-averaged-model", + type=str2bool, + default=True, + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", + ) + parser.add_argument( "--exp-dir", type=str, - default="pruned_transducer_stateless2/exp", + default="lstm_transducer_stateless/exp", help="The experiment dir", ) @@ -330,7 +344,7 @@ def decode_one_batch( Return the decoding result. See above description for the format of the returned dict. """ - device = model.device + device = next(model.parameters()).device feature = batch["inputs"] assert feature.ndim == 3 @@ -433,7 +447,7 @@ def decode_one_batch( for i in range(batch_size): # fmt: off - encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]] + encoder_out_i = encoder_out[i:i + 1, :encoder_out_lens[i]] # fmt: on if params.decoding_method == "greedy_search": hyp = greedy_search( @@ -455,14 +469,6 @@ def decode_one_batch( if params.decoding_method == "greedy_search": return {"greedy_search": hyps} - elif params.decoding_method == "fast_beam_search": - return { - ( - f"beam_{params.beam}_" - f"max_contexts_{params.max_contexts}_" - f"max_states_{params.max_states}" - ): hyps - } elif "fast_beam_search" in params.decoding_method: key = f"beam_{params.beam}_" key += f"max_contexts_{params.max_contexts}_" @@ -530,8 +536,8 @@ def decode_dataset( params=params, model=model, sp=sp, - word_table=word_table, decoding_graph=decoding_graph, + word_table=word_table, batch=batch, ) @@ -642,6 +648,9 @@ def main(): params.suffix += f"-context-{params.context_size}" params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" + if params.use_averaged_model: + params.suffix += "-use-averaged-model" + setup_logger(f"{params.res_dir}/log-decode-{params.suffix}") logging.info("Decoding started") @@ -659,43 +668,95 @@ def main(): params.unk_id = sp.piece_to_id("") params.vocab_size = sp.get_piece_size() + if params.simulate_streaming: + assert ( + params.causal_convolution + ), "Decoding in streaming requires causal convolution" + logging.info(params) logging.info("About to create model") model = get_transducer_model(params) - if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg - ] - if len(filenames) == 0: - raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - elif len(filenames) < params.avg: - raise ValueError( - f"Not enough checkpoints ({len(filenames)}) found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - logging.info(f"averaging {filenames}") - model.to(device) - model.load_state_dict(average_checkpoints(filenames, device=device)) - elif params.avg == 1: - load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + if not params.use_averaged_model: + if params.iter > 0: + filenames = find_checkpoints( + params.exp_dir, iteration=-params.iter + )[: params.avg] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + elif params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if i >= 1: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) else: - start = params.epoch - params.avg + 1 - filenames = [] - for i in range(start, params.epoch + 1): - if start >= 0: - filenames.append(f"{params.exp_dir}/epoch-{i}.pt") - logging.info(f"averaging {filenames}") - model.to(device) - model.load_state_dict(average_checkpoints(filenames, device=device)) + if params.iter > 0: + filenames = find_checkpoints( + params.exp_dir, iteration=-params.iter + )[: params.avg + 1] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg + 1: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + filename_start = filenames[-1] + filename_end = filenames[0] + logging.info( + "Calculating the averaged model over iteration checkpoints" + f" from {filename_start} (excluded) to {filename_end}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + else: + assert params.avg > 0, params.avg + start = params.epoch - params.avg + assert start >= 1, start + filename_start = f"{params.exp_dir}/epoch-{start}.pt" + filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" + logging.info( + f"Calculating the averaged model over epoch range from " + f"{start} (excluded) to {params.epoch}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) model.to(device) model.eval() - model.device = device if "fast_beam_search" in params.decoding_method: if params.decoding_method == "fast_beam_search_nbest_LG": diff --git a/egs/librispeech/ASR/lstm_transducer_stateless/streaming_decode.py b/egs/librispeech/ASR/lstm_transducer_stateless/streaming_decode.py index 6ba72ca2a..412718e33 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless/streaming_decode.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless/streaming_decode.py @@ -18,12 +18,12 @@ """ Usage: ./lstm_transducer_stateless/streaming_decode.py \ - --epoch 28 \ - --avg 15 \ - --exp-dir ./lstm_transducer_stateless/exp \ - --decoding_method greedy_search \ - --decode-chunk-size 1 \ - --num-decode-streams 1000 + --epoch 28 \ + --avg 15 \ + --decode-chunk-size 1 \ + --exp-dir ./lstm_transducer_stateless/exp \ + --decoding_method greedy_search \ + --num-decode-streams 200 """ import argparse @@ -46,6 +46,7 @@ from train import get_params, get_transducer_model from icefall.checkpoint import ( average_checkpoints, + average_checkpoints_with_averaged_model, find_checkpoints, load_checkpoint, ) @@ -55,6 +56,7 @@ from icefall.utils import ( get_texts, setup_logger, store_transcripts, + str2bool, write_error_stats, ) @@ -94,6 +96,17 @@ def get_parser(): "'--epoch' and '--iter'", ) + parser.add_argument( + "--use-averaged-model", + type=str2bool, + default=True, + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", + ) + parser.add_argument( "--exp-dir", type=str, @@ -164,6 +177,8 @@ def get_parser(): help="The number of streams that can be decoded parallel.", ) + add_model_arguments(parser) + return parser @@ -421,8 +436,8 @@ def decode_dataset( decode_results = [] # Contain decode streams currently running. decode_streams = [] - initial_states = model.encoder.get_init_states( - device=device + initial_states = model.encoder.get_init_state( + params.left_context, device=device ) for num, cut in enumerate(cuts): # each utterance has a DecodeStream. @@ -507,8 +522,6 @@ def save_results( recog_path = ( params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" ) - # sort results so we can easily compare the difference between two - # recognition results results = sorted(results) store_transcripts(filename=recog_path, texts=results) logging.info(f"The transcripts are stored in {recog_path}") @@ -570,6 +583,9 @@ def main(): params.suffix += f"-max-contexts-{params.max_contexts}" params.suffix += f"-max-states-{params.max_states}" + if params.use_averaged_model: + params.suffix += "-use-averaged-model" + setup_logger(f"{params.res_dir}/log-decode-{params.suffix}") logging.info("Decoding started") @@ -587,39 +603,90 @@ def main(): params.unk_id = sp.piece_to_id("") params.vocab_size = sp.get_piece_size() + # Decoding in streaming requires causal convolution + params.causal_convolution = True + logging.info(params) logging.info("About to create model") model = get_transducer_model(params) - if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg - ] - if len(filenames) == 0: - raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - elif len(filenames) < params.avg: - raise ValueError( - f"Not enough checkpoints ({len(filenames)}) found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - logging.info(f"averaging {filenames}") - model.to(device) - model.load_state_dict(average_checkpoints(filenames, device=device)) - elif params.avg == 1: - load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + if not params.use_averaged_model: + if params.iter > 0: + filenames = find_checkpoints( + params.exp_dir, iteration=-params.iter + )[: params.avg] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + elif params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if start >= 0: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) else: - start = params.epoch - params.avg + 1 - filenames = [] - for i in range(start, params.epoch + 1): - if start >= 0: - filenames.append(f"{params.exp_dir}/epoch-{i}.pt") - logging.info(f"averaging {filenames}") - model.to(device) - model.load_state_dict(average_checkpoints(filenames, device=device)) + if params.iter > 0: + filenames = find_checkpoints( + params.exp_dir, iteration=-params.iter + )[: params.avg + 1] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg + 1: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + filename_start = filenames[-1] + filename_end = filenames[0] + logging.info( + "Calculating the averaged model over iteration checkpoints" + f" from {filename_start} (excluded) to {filename_end}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + else: + assert params.avg > 0, params.avg + start = params.epoch - params.avg + assert start >= 1, start + filename_start = f"{params.exp_dir}/epoch-{start}.pt" + filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" + logging.info( + f"Calculating the averaged model over epoch range from " + f"{start} (excluded) to {params.epoch}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) model.to(device) model.eval() diff --git a/egs/librispeech/ASR/lstm_transducer_stateless/train.py b/egs/librispeech/ASR/lstm_transducer_stateless/train.py index 8ce5bdc54..738a880eb 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless/train.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless/train.py @@ -1,7 +1,8 @@ #!/usr/bin/env python3 # Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang, -# Wei Kang -# Mingshuang Luo) +# Wei Kang, +# Mingshuang Luo,) +# Zengwei Yao) # # See ../../../../LICENSE for clarification regarding multiple authors # @@ -24,7 +25,7 @@ export CUDA_VISIBLE_DEVICES="0,1,2,3" ./lstm_transducer_stateless/train.py \ --world-size 4 \ --num-epochs 30 \ - --start-epoch 0 \ + --start-epoch 1 \ --exp-dir lstm_transducer_stateless/exp \ --full-libri 1 \ --max-duration 300 @@ -34,15 +35,15 @@ export CUDA_VISIBLE_DEVICES="0,1,2,3" ./lstm_transducer_stateless/train.py \ --world-size 4 \ --num-epochs 30 \ - --start-epoch 0 \ + --start-epoch 1 \ --use-fp16 1 \ --exp-dir lstm_transducer_stateless/exp \ --full-libri 1 \ --max-duration 550 """ - import argparse +import copy import logging import warnings from pathlib import Path @@ -72,7 +73,10 @@ from torch.utils.tensorboard import SummaryWriter from icefall import diagnostics from icefall.checkpoint import load_checkpoint, remove_checkpoints from icefall.checkpoint import save_checkpoint as save_checkpoint_impl -from icefall.checkpoint import save_checkpoint_with_global_batch_idx +from icefall.checkpoint import ( + save_checkpoint_with_global_batch_idx, + update_averaged_model, +) from icefall.dist import cleanup_dist, setup_dist from icefall.env import get_env_info from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool @@ -118,10 +122,10 @@ def get_parser(): parser.add_argument( "--start-epoch", type=int, - default=0, - help="""Resume training from from this epoch. - If it is positive, it will load checkpoint from - transducer_stateless2/exp/epoch-{start_epoch-1}.pt + default=1, + help="""Resume training from this epoch. It should be positive. + If larger than 1, it will load checkpoint from + exp-dir/epoch-{start_epoch-1}.pt """, ) @@ -137,7 +141,7 @@ def get_parser(): parser.add_argument( "--exp-dir", type=str, - default="pruned_transducer_stateless2/exp", + default="lstm_transducer_stateless/exp", help="""The experiment dir. It specifies the directory where all training related files, e.g., checkpoints, log, etc, are saved @@ -155,16 +159,16 @@ def get_parser(): "--initial-lr", type=float, default=0.003, - help="The initial learning rate. This value should not need to " - "be changed.", + help="""The initial learning rate. This value should not need to be + changed.""", ) parser.add_argument( "--lr-batches", type=float, default=5000, - help="""Number of steps that affects how rapidly the learning rate - decreases. We suggest not to change this.""", + help="""Number of steps that affects how rapidly the learning rate decreases. + We suggest not to change this.""", ) parser.add_argument( @@ -255,6 +259,19 @@ def get_parser(): """, ) + parser.add_argument( + "--average-period", + type=int, + default=100, + help="""Update the averaged model, namely `model_avg`, after processing + this number of batches. `model_avg` is a separate version of model, + in which each floating-point parameter is the average of all the + parameters from the start of training. Each time we take the average, + we do: `model_avg = model * (average_period / batch_idx_train) + + model_avg * ((batch_idx_train - average_period) / batch_idx_train)`. + """, + ) + parser.add_argument( "--use-fp16", type=str2bool, @@ -390,6 +407,7 @@ def get_transducer_model(params: AttributeDict) -> nn.Module: def load_checkpoint_if_available( params: AttributeDict, model: nn.Module, + model_avg: nn.Module = None, optimizer: Optional[torch.optim.Optimizer] = None, scheduler: Optional[LRSchedulerType] = None, ) -> Optional[Dict[str, Any]]: @@ -397,7 +415,7 @@ def load_checkpoint_if_available( If params.start_batch is positive, it will load the checkpoint from `params.exp_dir/checkpoint-{params.start_batch}.pt`. Otherwise, if - params.start_epoch is positive, it will load the checkpoint from + params.start_epoch is larger than 1, it will load the checkpoint from `params.start_epoch - 1`. Apart from loading state dict for `model` and `optimizer` it also updates @@ -409,6 +427,8 @@ def load_checkpoint_if_available( The return value of :func:`get_params`. model: The training model. + model_avg: + The stored model averaged from the start of training. optimizer: The optimizer that we are using. scheduler: @@ -418,7 +438,7 @@ def load_checkpoint_if_available( """ if params.start_batch > 0: filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt" - elif params.start_epoch > 0: + elif params.start_epoch > 1: filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt" else: return None @@ -428,6 +448,7 @@ def load_checkpoint_if_available( saved_params = load_checkpoint( filename, model=model, + model_avg=model_avg, optimizer=optimizer, scheduler=scheduler, ) @@ -454,7 +475,8 @@ def load_checkpoint_if_available( def save_checkpoint( params: AttributeDict, - model: nn.Module, + model: Union[nn.Module, DDP], + model_avg: Optional[nn.Module] = None, optimizer: Optional[torch.optim.Optimizer] = None, scheduler: Optional[LRSchedulerType] = None, sampler: Optional[CutSampler] = None, @@ -468,6 +490,8 @@ def save_checkpoint( It is returned by :func:`get_params`. model: The training model. + model_avg: + The stored model averaged from the start of training. optimizer: The optimizer used in the training. sampler: @@ -481,6 +505,7 @@ def save_checkpoint( save_checkpoint_impl( filename=filename, model=model, + model_avg=model_avg, params=params, optimizer=optimizer, scheduler=scheduler, @@ -500,7 +525,7 @@ def save_checkpoint( def compute_loss( params: AttributeDict, - model: nn.Module, + model: Union[nn.Module, DDP], sp: spm.SentencePieceProcessor, batch: dict, is_training: bool, @@ -524,7 +549,11 @@ def compute_loss( warmup: a floating point value which increases throughout training; values >= 1.0 are fully warmed up and have all modules present. """ - device = model.device + device = ( + model.device + if isinstance(model, DDP) + else next(model.parameters()).device + ) feature = batch["inputs"] # at entry, feature is (N, T, C) assert feature.ndim == 3 @@ -580,7 +609,7 @@ def compute_loss( def compute_validation_loss( params: AttributeDict, - model: nn.Module, + model: Union[nn.Module, DDP], sp: spm.SentencePieceProcessor, valid_dl: torch.utils.data.DataLoader, world_size: int = 1, @@ -614,13 +643,14 @@ def compute_validation_loss( def train_one_epoch( params: AttributeDict, - model: nn.Module, + model: Union[nn.Module, DDP], optimizer: torch.optim.Optimizer, scheduler: LRSchedulerType, sp: spm.SentencePieceProcessor, train_dl: torch.utils.data.DataLoader, valid_dl: torch.utils.data.DataLoader, scaler: GradScaler, + model_avg: Optional[nn.Module] = None, tb_writer: Optional[SummaryWriter] = None, world_size: int = 1, rank: int = 0, @@ -646,6 +676,8 @@ def train_one_epoch( Dataloader for the validation dataset. scaler: The scaler used for mix precision training. + model_avg: + The stored model averaged from the start of training. tb_writer: Writer to write log messages to tensorboard. world_size: @@ -668,33 +700,40 @@ def train_one_epoch( params.batch_idx_train += 1 batch_size = len(batch["supervisions"]["text"]) - try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): - loss, loss_info = compute_loss( - params=params, - model=model, - sp=sp, - batch=batch, - is_training=True, - warmup=(params.batch_idx_train / params.model_warm_step), - ) - # summary stats - tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info + with torch.cuda.amp.autocast(enabled=params.use_fp16): + loss, loss_info = compute_loss( + params=params, + model=model, + sp=sp, + batch=batch, + is_training=True, + warmup=(params.batch_idx_train / params.model_warm_step), + ) + # summary stats + tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info - # NOTE: We use reduction==sum and loss is computed over utterances - # in the batch and there is no normalization to it so far. - scaler.scale(loss).backward() - scheduler.step_batch(params.batch_idx_train) - scaler.step(optimizer) - scaler.update() - optimizer.zero_grad() - except: # noqa - display_and_save_batch(batch, params=params, sp=sp) - raise + # NOTE: We use reduction==sum and loss is computed over utterances + # in the batch and there is no normalization to it so far. + scaler.scale(loss).backward() + scheduler.step_batch(params.batch_idx_train) + scaler.step(optimizer) + scaler.update() + optimizer.zero_grad() if params.print_diagnostics and batch_idx == 30: return + if ( + rank == 0 + and params.batch_idx_train > 0 + and params.batch_idx_train % params.average_period == 0 + ): + update_averaged_model( + params=params, + model_cur=model, + model_avg=model_avg, + ) + if ( params.batch_idx_train > 0 and params.batch_idx_train % params.save_every_n == 0 @@ -704,6 +743,7 @@ def train_one_epoch( out_dir=params.exp_dir, global_batch_idx=params.batch_idx_train, model=model, + model_avg=model_avg, params=params, optimizer=optimizer, scheduler=scheduler, @@ -803,11 +843,6 @@ def run(rank, world_size, args): params.blank_id = sp.piece_to_id("") params.vocab_size = sp.get_piece_size() - if params.dynamic_chunk_training: - assert ( - params.causal_convolution - ), "dynamic_chunk_training requires causal convolution" - logging.info(params) logging.info("About to create model") @@ -816,13 +851,21 @@ def run(rank, world_size, args): num_param = sum([p.numel() for p in model.parameters()]) logging.info(f"Number of model parameters: {num_param}") - checkpoints = load_checkpoint_if_available(params=params, model=model) + assert params.save_every_n >= params.average_period + model_avg: Optional[nn.Module] = None + if rank == 0: + # model_avg is only used with rank 0 + model_avg = copy.deepcopy(model) + + assert params.start_epoch > 0, params.start_epoch + checkpoints = load_checkpoint_if_available( + params=params, model=model, model_avg=model_avg + ) model.to(device) if world_size > 1: logging.info("Using DDP") model = DDP(model, device_ids=[rank]) - model.device = device optimizer = Eve(model.parameters(), lr=params.initial_lr) @@ -885,7 +928,7 @@ def run(rank, world_size, args): optimizer=optimizer, sp=sp, params=params, - warmup=0.0 if params.start_epoch == 0 else 1.0, + warmup=0.0 if params.start_epoch == 1 else 1.0, ) scaler = GradScaler(enabled=params.use_fp16) @@ -893,10 +936,10 @@ def run(rank, world_size, args): logging.info("Loading grad scaler state dict") scaler.load_state_dict(checkpoints["grad_scaler"]) - for epoch in range(params.start_epoch, params.num_epochs): - scheduler.step_epoch(epoch) - fix_random_seed(params.seed + epoch) - train_dl.sampler.set_epoch(epoch) + for epoch in range(params.start_epoch, params.num_epochs + 1): + scheduler.step_epoch(epoch - 1) + fix_random_seed(params.seed + epoch - 1) + train_dl.sampler.set_epoch(epoch - 1) if tb_writer is not None: tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) @@ -906,6 +949,7 @@ def run(rank, world_size, args): train_one_epoch( params=params, model=model, + model_avg=model_avg, optimizer=optimizer, scheduler=scheduler, sp=sp, @@ -924,6 +968,7 @@ def run(rank, world_size, args): save_checkpoint( params=params, model=model, + model_avg=model_avg, optimizer=optimizer, scheduler=scheduler, sampler=train_dl.sampler, @@ -938,40 +983,8 @@ def run(rank, world_size, args): cleanup_dist() -def display_and_save_batch( - batch: dict, - params: AttributeDict, - sp: spm.SentencePieceProcessor, -) -> None: - """Display the batch statistics and save the batch into disk. - - Args: - batch: - A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` - for the content in it. - params: - Parameters for training. See :func:`get_params`. - sp: - The BPE model. - """ - from lhotse.utils import uuid4 - - filename = f"{params.exp_dir}/batch-{uuid4()}.pt" - logging.info(f"Saving batch to {filename}") - torch.save(batch, filename) - - supervisions = batch["supervisions"] - features = batch["inputs"] - - logging.info(f"features shape: {features.shape}") - - y = sp.encode(supervisions["text"], out_type=int) - num_tokens = sum(len(i) for i in y) - logging.info(f"num tokens: {num_tokens}") - - def scan_pessimistic_batches_for_oom( - model: nn.Module, + model: Union[nn.Module, DDP], train_dl: torch.utils.data.DataLoader, optimizer: torch.optim.Optimizer, sp: spm.SentencePieceProcessor, @@ -981,7 +994,7 @@ def scan_pessimistic_batches_for_oom( from lhotse.dataset import find_pessimistic_batches logging.info( - "Sanity check -- see if any of the batches in epoch 0 would cause OOM." + "Sanity check -- see if any of the batches in epoch 1 would cause OOM." ) batches, crit_values = find_pessimistic_batches(train_dl.sampler) for criterion, cuts in batches.items(): @@ -999,7 +1012,7 @@ def scan_pessimistic_batches_for_oom( loss.backward() optimizer.step() optimizer.zero_grad() - except Exception as e: + except RuntimeError as e: if "CUDA out of memory" in str(e): logging.error( "Your GPU ran out of memory with the current " @@ -1008,7 +1021,6 @@ def scan_pessimistic_batches_for_oom( f"Failing criterion: {criterion} " f"(={crit_values[criterion]}) ..." ) - display_and_save_batch(batch, params=params, sp=sp) raise