From 6d809bad0bfe6ea1beb36eabf7c71e2c5f7716e9 Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Thu, 5 May 2022 22:06:37 +0800 Subject: [PATCH] Merge changes from master. --- .flake8 | 2 + egs/librispeech/ASR/README.md | 2 + .../ASR/pruned_transducer_stateless4/train.py | 2 +- .../decode-giga.py | 93 +++++++++++++------ .../pruned_transducer_stateless5/decode.py | 93 +++++++++++++------ .../pruned_transducer_stateless5/sampling.py | 6 +- .../ASR/pruned_transducer_stateless5/train.py | 81 +++++++++++++--- icefall/checkpoint.py | 4 +- 8 files changed, 208 insertions(+), 75 deletions(-) diff --git a/.flake8 b/.flake8 index 190387886..8b165135e 100644 --- a/.flake8 +++ b/.flake8 @@ -20,4 +20,6 @@ exclude = .git, **/data/**, icefall/shared/make_kn_lm.py, + egs/librispeech/ASR/pruned_transducer_stateless5/conformer.py, + egs/librispeech/ASR/pruned_transducer_stateless5/sampling.py, icefall/__init__.py diff --git a/egs/librispeech/ASR/README.md b/egs/librispeech/ASR/README.md index c053076a3..14dbfe95f 100644 --- a/egs/librispeech/ASR/README.md +++ b/egs/librispeech/ASR/README.md @@ -19,6 +19,8 @@ The following table lists the differences among them. | `pruned_transducer_stateless` | Conformer | Embedding + Conv1d | Using k2 pruned RNN-T loss | | `pruned_transducer_stateless2` | Conformer(modified) | Embedding + Conv1d | Using k2 pruned RNN-T loss | | `pruned_transducer_stateless3` | Conformer(modified) | Embedding + Conv1d | Using k2 pruned RNN-T loss + using GigaSpeech as extra training data | +| `pruned_transducer_stateless4` | Conformer(modified) | Embedding + Conv1d | Same as pruned_transducer_stateless2 but supports saving averaged model periodically.| +| `pruned_transducer_stateless5` | Conformer(modified) | Embedding + Conv1d | Same as pruned_transducer_stateless3 but with knowledge bank| The decoder in `transducer_stateless` is modified from the paper diff --git a/egs/librispeech/ASR/pruned_transducer_stateless4/train.py b/egs/librispeech/ASR/pruned_transducer_stateless4/train.py index 147bcf658..de126a8e3 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless4/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless4/train.py @@ -411,7 +411,7 @@ def get_transducer_model(params: AttributeDict) -> nn.Module: def load_checkpoint_if_available( params: AttributeDict, model: nn.Module, - model_avg: nn.Module = None, + model_avg: Optional[nn.Module] = None, optimizer: Optional[torch.optim.Optimizer] = None, scheduler: Optional[LRSchedulerType] = None, ) -> Optional[Dict[str, Any]]: diff --git a/egs/librispeech/ASR/pruned_transducer_stateless5/decode-giga.py b/egs/librispeech/ASR/pruned_transducer_stateless5/decode-giga.py index 9ae17fd11..5d51ab478 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless5/decode-giga.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless5/decode-giga.py @@ -1,6 +1,8 @@ #!/usr/bin/env python3 # -# Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang) +# Copyright 2021-2022 Xiaomi Corporation (Author: Fangjun Kuang +# Zengwei Yao) +# # # See ../../../../LICENSE for clarification regarding multiple authors # @@ -81,6 +83,7 @@ from train import get_params, get_transducer_model from icefall.checkpoint import ( average_checkpoints, + average_checkpoints_with_averaged_model, find_checkpoints, load_checkpoint, ) @@ -88,6 +91,7 @@ from icefall.utils import ( AttributeDict, setup_logger, store_transcripts, + str2bool, write_error_stats, ) @@ -102,7 +106,7 @@ def get_parser(): type=int, default=28, help="""It specifies the checkpoint to use for decoding. - Note: Epoch counts from 0. + Note: Epoch counts from 1. You can specify --avg to use more checkpoints for model averaging.""", ) @@ -125,6 +129,17 @@ def get_parser(): "'--epoch' and '--iter'", ) + parser.add_argument( + "--use-averaged-model", + type=str2bool, + default=False, + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", + ) + parser.add_argument( "--exp-dir", type=str, @@ -538,6 +553,9 @@ def main(): params.suffix += f"-context-{params.context_size}" params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" + if params.use_averaged_model: + params.suffix += "-use-averaged-model" + setup_logger(f"{params.res_dir}/log-decode-{params.suffix}") logging.info("Decoding started") @@ -560,34 +578,53 @@ def main(): logging.info("About to create model") model = get_transducer_model(params) - if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg - ] - if len(filenames) == 0: - raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - elif len(filenames) < params.avg: - raise ValueError( - f"Not enough checkpoints ({len(filenames)}) found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - logging.info(f"averaging {filenames}") - model.to(device) - model.load_state_dict(average_checkpoints(filenames, device=device)) - elif params.avg == 1: - load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + if not params.use_averaged_model: + if params.iter > 0: + filenames = find_checkpoints( + params.exp_dir, iteration=-params.iter + )[: params.avg] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + elif params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if i >= 1: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) else: - start = params.epoch - params.avg + 1 - filenames = [] - for i in range(start, params.epoch + 1): - if start >= 0: - filenames.append(f"{params.exp_dir}/epoch-{i}.pt") - logging.info(f"averaging {filenames}") + assert params.iter == 0 and params.avg > 0 + start = params.epoch - params.avg + assert start >= 1 + filename_start = f"{params.exp_dir}/epoch-{start}.pt" + filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" + logging.info( + f"Calculating the averaged model over epoch range from " + f"{start} (excluded) to {params.epoch}" + ) model.to(device) - model.load_state_dict(average_checkpoints(filenames, device=device)) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) model.to(device) model.eval() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless5/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless5/decode.py index 865709833..1cfdd57a3 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless5/decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless5/decode.py @@ -1,6 +1,8 @@ #!/usr/bin/env python3 # -# Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang) +# Copyright 2021-2022 Xiaomi Corporation (Author: Fangjun Kuang +# Zengwei Yao) +# # # See ../../../../LICENSE for clarification regarding multiple authors # @@ -80,6 +82,7 @@ from train import get_params, get_transducer_model from icefall.checkpoint import ( average_checkpoints, + average_checkpoints_with_averaged_model, find_checkpoints, load_checkpoint, ) @@ -87,6 +90,7 @@ from icefall.utils import ( AttributeDict, setup_logger, store_transcripts, + str2bool, write_error_stats, ) @@ -101,7 +105,7 @@ def get_parser(): type=int, default=28, help="""It specifies the checkpoint to use for decoding. - Note: Epoch counts from 0. + Note: Epoch counts from 1. You can specify --avg to use more checkpoints for model averaging.""", ) @@ -124,6 +128,17 @@ def get_parser(): "'--epoch' and '--iter'", ) + parser.add_argument( + "--use-averaged-model", + type=str2bool, + default=False, + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", + ) + parser.add_argument( "--exp-dir", type=str, @@ -525,6 +540,9 @@ def main(): params.suffix += f"-context-{params.context_size}" params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" + if params.use_averaged_model: + params.suffix += "-use-averaged-model" + setup_logger(f"{params.res_dir}/log-decode-{params.suffix}") logging.info("Decoding started") @@ -547,34 +565,53 @@ def main(): logging.info("About to create model") model = get_transducer_model(params) - if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg - ] - if len(filenames) == 0: - raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - elif len(filenames) < params.avg: - raise ValueError( - f"Not enough checkpoints ({len(filenames)}) found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - logging.info(f"averaging {filenames}") - model.to(device) - model.load_state_dict(average_checkpoints(filenames, device=device)) - elif params.avg == 1: - load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + if not params.use_averaged_model: + if params.iter > 0: + filenames = find_checkpoints( + params.exp_dir, iteration=-params.iter + )[: params.avg] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + elif params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if i >= 1: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) else: - start = params.epoch - params.avg + 1 - filenames = [] - for i in range(start, params.epoch + 1): - if start >= 0: - filenames.append(f"{params.exp_dir}/epoch-{i}.pt") - logging.info(f"averaging {filenames}") + assert params.iter == 0 and params.avg > 0 + start = params.epoch - params.avg + assert start >= 1 + filename_start = f"{params.exp_dir}/epoch-{start}.pt" + filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" + logging.info( + f"Calculating the averaged model over epoch range from " + f"{start} (excluded) to {params.epoch}" + ) model.to(device) - model.load_state_dict(average_checkpoints(filenames, device=device)) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) model.to(device) model.eval() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless5/sampling.py b/egs/librispeech/ASR/pruned_transducer_stateless5/sampling.py index 26f0d26b0..d53062c84 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless5/sampling.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless5/sampling.py @@ -86,8 +86,8 @@ class WeightedMatrixLookupFunction(torch.autograd.Function): tensor of shape (*, D), containing weighted sums of rows of `knowledge_base` """ - if random.random() < 0.001: - print("dtype[1] = ", weights.dtype) + # if random.random() < 0.001: + # print("dtype[1] = ", weights.dtype) ctx.save_for_backward(weights.detach(), indexes.detach(), knowledge_base.detach()) with torch.no_grad(): @@ -174,7 +174,7 @@ class KnowledgeBaseLookup(nn.Module): assert torch.all(x - x == 0) if random.random() < 0.001: entropy = (x * x.exp()).sum(dim=-1).mean() - print("Entropy = ", entropy) + # print("Entropy = ", entropy) # only need 'combined_indexes', call them 'indexes'. _, indexes, weights = sample_combined(x, self.K, input_is_log=True) x = weighted_matrix_lookup(weights, indexes, self.knowledge_base) # now (*, D) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless5/train.py b/egs/librispeech/ASR/pruned_transducer_stateless5/train.py index dcedcfec6..62eeffb0e 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless5/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless5/train.py @@ -1,7 +1,8 @@ #!/usr/bin/env python3 -# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang, +# Copyright 2021-2022 Xiaomi Corp. (authors: Fangjun Kuang # Wei Kang -# Mingshuang Luo) +# Mingshuang Luo +# Zengwei Yao) # # See ../../../../LICENSE for clarification regarding multiple authors # @@ -48,6 +49,7 @@ cd egs/librispeech/ASR/ import argparse +import copy import logging import random import warnings @@ -81,7 +83,10 @@ from torch.utils.tensorboard import SummaryWriter from icefall import diagnostics from icefall.checkpoint import load_checkpoint, remove_checkpoints from icefall.checkpoint import save_checkpoint as save_checkpoint_impl -from icefall.checkpoint import save_checkpoint_with_global_batch_idx +from icefall.checkpoint import ( + save_checkpoint_with_global_batch_idx, + update_averaged_model, +) from icefall.dist import cleanup_dist, setup_dist from icefall.env import get_env_info from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool @@ -135,10 +140,10 @@ def get_parser(): parser.add_argument( "--start-epoch", type=int, - default=0, + default=1, help="""Resume training from from this epoch. If it is positive, it will load checkpoint from - transducer_stateless3/exp/epoch-{start_epoch-1}.pt + exp-dir/epoch-{start_epoch-1}.pt """, ) @@ -272,6 +277,19 @@ def get_parser(): """, ) + parser.add_argument( + "--average-period", + type=int, + default=1000, + help="""Update the averaged model, namely `model_avg`, after processing + this number of batches. `model_avg` is a separate version of model, + in which each floating-point parameter is the average of all the + parameters from the start of training. Each time we take the average, + we do: `model_avg = model * (average_period / batch_idx_train) + + model_avg * ((batch_idx_train - average_period) / batch_idx_train)`. + """, + ) + parser.add_argument( "--use-fp16", type=str2bool, @@ -423,6 +441,7 @@ def get_transducer_model(params: AttributeDict) -> nn.Module: def load_checkpoint_if_available( params: AttributeDict, model: nn.Module, + model_avg: Optional[nn.Module] = None, optimizer: Optional[torch.optim.Optimizer] = None, scheduler: Optional[LRSchedulerType] = None, ) -> Optional[Dict[str, Any]]: @@ -430,7 +449,7 @@ def load_checkpoint_if_available( If params.start_batch is positive, it will load the checkpoint from `params.exp_dir/checkpoint-{params.start_batch}.pt`. Otherwise, if - params.start_epoch is positive, it will load the checkpoint from + params.start_epoch is larger than 1, it will load the checkpoint from `params.start_epoch - 1`. Apart from loading state dict for `model` and `optimizer` it also updates @@ -442,6 +461,8 @@ def load_checkpoint_if_available( The return value of :func:`get_params`. model: The training model. + model_avg: + The stored model averaged from the start of training. optimizer: The optimizer that we are using. scheduler: @@ -451,7 +472,7 @@ def load_checkpoint_if_available( """ if params.start_batch > 0: filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt" - elif params.start_epoch > 0: + elif params.start_epoch > 1: filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt" else: return None @@ -461,6 +482,7 @@ def load_checkpoint_if_available( saved_params = load_checkpoint( filename, model=model, + model_avg=model_avg, optimizer=optimizer, scheduler=scheduler, ) @@ -485,6 +507,7 @@ def load_checkpoint_if_available( def save_checkpoint( params: AttributeDict, model: nn.Module, + model_avg: Optional[nn.Module] = None, optimizer: Optional[torch.optim.Optimizer] = None, scheduler: Optional[LRSchedulerType] = None, sampler: Optional[CutSampler] = None, @@ -498,6 +521,8 @@ def save_checkpoint( It is returned by :func:`get_params`. model: The training model. + model_avg: + The stored model averaged from the start of training. optimizer: The optimizer used in the training. sampler: @@ -511,6 +536,7 @@ def save_checkpoint( save_checkpoint_impl( filename=filename, model=model, + model_avg=model_avg, params=params, optimizer=optimizer, scheduler=scheduler, @@ -667,6 +693,7 @@ def train_one_epoch( valid_dl: torch.utils.data.DataLoader, rng: random.Random, scaler: GradScaler, + model_avg: Optional[nn.Module] = None, tb_writer: Optional[SummaryWriter] = None, world_size: int = 1, rank: int = 0, @@ -696,6 +723,8 @@ def train_one_epoch( For selecting which dataset to use. scaler: The scaler used for mix precision training. + model_avg: + The stored model averaged from the start of training. tb_writer: Writer to write log messages to tensorboard. world_size: @@ -772,6 +801,17 @@ def train_one_epoch( if params.print_diagnostics and batch_idx == 5: return + if ( + rank == 0 + and params.batch_idx_train > 0 + and params.batch_idx_train % params.average_period == 0 + ): + update_averaged_model( + params=params, + model_cur=model, + model_avg=model_avg, + ) + if ( params.batch_idx_train > 0 and params.batch_idx_train % params.save_every_n == 0 @@ -780,6 +820,7 @@ def train_one_epoch( out_dir=params.exp_dir, global_batch_idx=params.batch_idx_train, model=model, + model_avg=model_avg, params=params, optimizer=optimizer, scheduler=scheduler, @@ -915,7 +956,15 @@ def run(rank, world_size, args): num_param = sum([p.numel() for p in model.parameters()]) logging.info(f"Number of model parameters: {num_param}") - checkpoints = load_checkpoint_if_available(params=params, model=model) + assert params.save_every_n >= params.average_period + model_avg: Optional[nn.Module] = None + if rank == 0: + # model_avg is only used with rank 0 + model_avg = copy.deepcopy(model) + + checkpoints = load_checkpoint_if_available( + params=params, model=model, model_avg=model_avg + ) model.to(device) if world_size > 1: @@ -923,6 +972,10 @@ def run(rank, world_size, args): model = DDP(model, device_ids=[rank], find_unused_parameters=True) model.device = device + if rank == 0: + model_avg.to(device) + model_avg.device = device + optimizer = Eve(model.parameters(), lr=params.initial_lr) scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs) @@ -1014,10 +1067,10 @@ def run(rank, world_size, args): logging.info("Loading grad scaler state dict") scaler.load_state_dict(checkpoints["grad_scaler"]) - for epoch in range(params.start_epoch, params.num_epochs): - scheduler.step_epoch(epoch) - fix_random_seed(params.seed + epoch) - train_dl.sampler.set_epoch(epoch) + for epoch in range(params.start_epoch, params.num_epochs + 1): + scheduler.step_epoch(epoch - 1) + fix_random_seed(params.seed + epoch - 1) + train_dl.sampler.set_epoch(epoch - 1) if tb_writer is not None: tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) @@ -1027,6 +1080,7 @@ def run(rank, world_size, args): train_one_epoch( params=params, model=model, + model_avg=model_avg, optimizer=optimizer, scheduler=scheduler, sp=sp, @@ -1047,6 +1101,7 @@ def run(rank, world_size, args): save_checkpoint( params=params, model=model, + model_avg=model_avg, optimizer=optimizer, scheduler=scheduler, sampler=train_dl.sampler, @@ -1071,7 +1126,7 @@ def scan_pessimistic_batches_for_oom( from lhotse.dataset import find_pessimistic_batches logging.info( - "Sanity check -- see if any of the batches in epoch 0 would cause OOM." + "Sanity check -- see if any of the batches in epoch 1 would cause OOM." ) batches, crit_values = find_pessimistic_batches(train_dl.sampler) for criterion, cuts in batches.items(): diff --git a/icefall/checkpoint.py b/icefall/checkpoint.py index 5b562ccc8..2ca173663 100644 --- a/icefall/checkpoint.py +++ b/icefall/checkpoint.py @@ -346,7 +346,7 @@ def remove_checkpoints( for c in to_remove: os.remove(c) - +@torch.no_grad() def update_averaged_model( params: Dict[str, Tensor], model_cur: Union[nn.Module, DDP], @@ -442,7 +442,7 @@ def average_checkpoints_with_averaged_model( return avg - +@torch.no_grad() def average_state_dict( state_dict_1: Dict[str, Tensor], state_dict_2: Dict[str, Tensor],