From 1651fe0d42c562c2fd562b2927a2d835c679ec8c Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Tue, 31 May 2022 13:00:11 +0800 Subject: [PATCH] Merge changes from pruned_transducer_stateless4->5 --- .../pruned_transducer_stateless2/export.py | 144 +++++--- .../pruned_transducer_stateless7/__init__.py | 1 - .../pruned_transducer_stateless7/conformer.py | 313 +++++++++++++++++- .../pruned_transducer_stateless7/decode.py | 34 +- .../test_model.py | 31 +- .../ASR/pruned_transducer_stateless7/train.py | 169 +++++++--- 6 files changed, 569 insertions(+), 123 deletions(-) mode change 120000 => 100644 egs/librispeech/ASR/pruned_transducer_stateless7/__init__.py diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/export.py b/egs/librispeech/ASR/pruned_transducer_stateless2/export.py index cff9c7377..f1269a4bd 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/export.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/export.py @@ -20,26 +20,27 @@ # to a single one using model averaging. """ Usage: -./pruned_transducer_stateless2/export.py \ - --exp-dir ./pruned_transducer_stateless2/exp \ +./pruned_transducer_stateless5/export.py \ + --exp-dir ./pruned_transducer_stateless5/exp \ --bpe-model data/lang_bpe_500/bpe.model \ --epoch 20 \ --avg 10 It will generate a file exp_dir/pretrained.pt -To use the generated file with `pruned_transducer_stateless2/decode.py`, +To use the generated file with `pruned_transducer_stateless5/decode.py`, you can do: cd /path/to/exp_dir ln -s pretrained.pt epoch-9999.pt cd /path/to/egs/librispeech/ASR - ./pruned_transducer_stateless2/decode.py \ - --exp-dir ./pruned_transducer_stateless2/exp \ + ./pruned_transducer_stateless5/decode.py \ + --exp-dir ./pruned_transducer_stateless5/exp \ --epoch 9999 \ --avg 1 \ - --max-duration 100 \ + --max-duration 600 \ + --decoding-method greedy_search \ --bpe-model data/lang_bpe_500/bpe.model """ @@ -49,10 +50,11 @@ from pathlib import Path import sentencepiece as spm import torch -from train import get_params, get_transducer_model +from train import add_model_arguments, get_params, get_transducer_model from icefall.checkpoint import ( average_checkpoints, + average_checkpoints_with_averaged_model, find_checkpoints, load_checkpoint, ) @@ -69,7 +71,7 @@ def get_parser(): type=int, default=28, help="""It specifies the checkpoint to use for averaging. - Note: Epoch counts from 0. + Note: Epoch counts from 1. You can specify --avg to use more checkpoints for model averaging.""", ) @@ -92,10 +94,21 @@ def get_parser(): "'--epoch' and '--iter'", ) + parser.add_argument( + "--use-averaged-model", + type=str2bool, + default=False, + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. 
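That is, exactly --avg epochs' worth of parameters contribute to the average.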
", + ) + parser.add_argument( "--exp-dir", type=str, - default="pruned_transducer_stateless2/exp", + default="pruned_transducer_stateless5/exp", help="""It specifies the directory where all training related files, e.g., checkpoints, log, etc, are saved """, @@ -124,6 +137,8 @@ def get_parser(): "2 means tri-gram", ) + add_model_arguments(parser) + return parser @@ -131,6 +146,8 @@ def main(): args = get_parser().parse_args() args.exp_dir = Path(args.exp_dir) + assert args.jit is False, "Support torchscript will be added later" + params = get_params() params.update(vars(args)) @@ -152,36 +169,82 @@ def main(): logging.info("About to create model") model = get_transducer_model(params) - model.to(device) - - if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg - ] - if len(filenames) == 0: - raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - elif len(filenames) < params.avg: - raise ValueError( - f"Not enough checkpoints ({len(filenames)}) found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - logging.info(f"averaging {filenames}") - model.to(device) - model.load_state_dict(average_checkpoints(filenames, device=device)) - elif params.avg == 1: - load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + if not params.use_averaged_model: + if params.iter > 0: + filenames = find_checkpoints( + params.exp_dir, iteration=-params.iter + )[: params.avg] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + elif params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if i >= 1: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) else: - start = params.epoch - params.avg + 1 - filenames = [] - for i in range(start, params.epoch + 1): - if start >= 0: - filenames.append(f"{params.exp_dir}/epoch-{i}.pt") - logging.info(f"averaging {filenames}") - model.to(device) - model.load_state_dict(average_checkpoints(filenames, device=device)) + if params.iter > 0: + filenames = find_checkpoints( + params.exp_dir, iteration=-params.iter + )[: params.avg + 1] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg + 1: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + filename_start = filenames[-1] + filename_end = filenames[0] + logging.info( + "Calculating the averaged model over iteration checkpoints" + f" from {filename_start} (excluded) to {filename_end}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + else: + assert params.avg > 0, params.avg + start = params.epoch - params.avg + assert start >= 1, start + filename_start = 
f"{params.exp_dir}/epoch-{start}.pt" + filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" + logging.info( + f"Calculating the averaged model over epoch range from " + f"{start} (excluded) to {params.epoch}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) model.eval() @@ -189,11 +252,6 @@ def main(): model.eval() if params.jit: - # We won't use the forward() method of the model in C++, so just ignore - # it here. - # Otherwise, one of its arguments is a ragged tensor and is not - # torch scriptabe. - model.__class__.forward = torch.jit.ignore(model.__class__.forward) logging.info("Using torch.jit.script") model = torch.jit.script(model) filename = params.exp_dir / "cpu_jit.pt" diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/__init__.py b/egs/librispeech/ASR/pruned_transducer_stateless7/__init__.py deleted file mode 120000 index b24e5e357..000000000 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/__init__.py +++ /dev/null @@ -1 +0,0 @@ -../pruned_transducer_stateless2/__init__.py \ No newline at end of file diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/__init__.py b/egs/librispeech/ASR/pruned_transducer_stateless7/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py index 99a6f7ce5..00ec14408 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py @@ -18,7 +18,7 @@ import copy import math import warnings -from typing import Optional, Tuple +from typing import List, Optional, Tuple import torch from encoder_interface import EncoderInterface @@ -61,6 +61,7 @@ class Conformer(EncoderInterface): dropout: float = 0.1, layer_dropout: float = 0.075, cnn_module_kernel: int = 31, + aux_layer_period: int = 3, ) -> None: super(Conformer, self).__init__() @@ -86,7 +87,11 @@ class Conformer(EncoderInterface): layer_dropout, cnn_module_kernel, ) - self.encoder = ConformerEncoder(encoder_layer, num_encoder_layers) + self.encoder = ConformerEncoder( + encoder_layer, + num_encoder_layers, + aux_layers=list(range(0, num_encoder_layers - 1, aux_layer_period)), + ) def forward( self, x: torch.Tensor, x_lens: torch.Tensor, warmup: float = 1.0 @@ -112,13 +117,10 @@ class Conformer(EncoderInterface): x, pos_emb = self.encoder_pos(x) x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C) - # Caution: We assume the subsampling factor is 4! - - # lengths = ((x_lens - 1) // 2 - 1) // 2 # issue an warning - # - # Note: rounding_mode in torch.div() is available only in torch >= 1.8.0 - lengths = (((x_lens - 1) >> 1) - 1) >> 1 - + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + # Caution: We assume the subsampling factor is 4! 
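+            # e.g. x_lens = 100 -> ((100 - 1) // 2 - 1) // 2 = 24 output frames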
+ lengths = ((x_lens - 1) // 2 - 1) // 2 assert x.size(0) == lengths.max().item() mask = make_pad_mask(lengths) @@ -282,13 +284,30 @@ class ConformerEncoder(nn.Module): >>> out = conformer_encoder(src, pos_emb) """ - def __init__(self, encoder_layer: nn.Module, num_layers: int) -> None: + def __init__( + self, + encoder_layer: nn.Module, + num_layers: int, + aux_layers: List[int], + ) -> None: super().__init__() self.layers = nn.ModuleList( [copy.deepcopy(encoder_layer) for i in range(num_layers)] ) self.num_layers = num_layers + assert num_layers - 1 not in aux_layers + self.aux_layers = set(aux_layers + [num_layers - 1]) + + num_channels = encoder_layer.norm_final.num_channels + self.combiner = RandomCombine( + num_inputs=len(self.aux_layers), + num_channels=num_channels, + final_weight=0.5, + pure_prob=0.333, + stddev=2.0, + ) + def forward( self, src: Tensor, @@ -315,6 +334,8 @@ class ConformerEncoder(nn.Module): """ output = src + outputs = [] + for i, mod in enumerate(self.layers): output = mod( output, @@ -323,6 +344,10 @@ class ConformerEncoder(nn.Module): src_key_padding_mask=src_key_padding_mask, warmup=warmup, ) + if i in self.aux_layers: + outputs.append(output) + + output = self.combiner(outputs) return output @@ -1022,15 +1047,281 @@ class Conv2dSubsampling(nn.Module): x = self.out_balancer(x) return x +class RandomCombine(nn.Module): + """ + This module combines a list of Tensors, all with the same shape, to + produce a single output of that same shape which, in training time, + is a random combination of all the inputs; but which in test time + will be just the last input. -if __name__ == "__main__": + All but the last input will have a linear transform before we + randomly combine them; these linear transforms will be initialized + to the identity transform. + + The idea is that the list of Tensors will be a list of outputs of multiple + conformer layers. This has a similar effect as iterated loss. (See: + DEJA-VU: DOUBLE FEATURE PRESENTATION AND ITERATED LOSS IN DEEP TRANSFORMER + NETWORKS). + """ + + def __init__( + self, + num_inputs: int, + num_channels: int, + final_weight: float = 0.5, + pure_prob: float = 0.5, + stddev: float = 2.0, + ) -> None: + """ + Args: + num_inputs: + The number of tensor inputs, which equals the number of layers' + outputs that are fed into this module. E.g. in an 18-layer neural + net if we output layers 16, 12, 18, num_inputs would be 3. + num_channels: + The number of channels on the input, e.g. 512. + final_weight: + The amount of weight or probability we assign to the + final layer when randomly choosing layers or when choosing + continuous layer weights. + pure_prob: + The probability, on each frame, with which we choose + only a single layer to output (rather than an interpolation) + stddev: + A standard deviation that we add to log-probs for computing + randomized weights. + + The method of choosing which layers, or combinations of layers, to use, + is conceptually as follows:: + + With probability `pure_prob`:: + With probability `final_weight`: choose final layer, + Else: choose random non-final layer. + Else:: + Choose initial log-weights that correspond to assigning + weight `final_weight` to the final layer and equal + weights to other layers; then add Gaussian noise + with variance `stddev` to these log-weights, and normalize + to weights (note: the average weight assigned to the + final layer here will not be `final_weight` if stddev>0). 
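+
+        For example, with num_inputs=3 and final_weight=0.5, the initial
+        log-weights correspond to the weights (0.25, 0.25, 0.5) before the
+        Gaussian noise is added.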
+ """ + super().__init__() + assert 0 <= pure_prob <= 1, pure_prob + assert 0 < final_weight < 1, final_weight + assert num_inputs >= 1 + + self.linear = nn.ModuleList( + [ + nn.Linear(num_channels, num_channels, bias=True) + for _ in range(num_inputs - 1) + ] + ) + + self.num_inputs = num_inputs + self.final_weight = final_weight + self.pure_prob = pure_prob + self.stddev = stddev + + self.final_log_weight = ( + torch.tensor( + (final_weight / (1 - final_weight)) * (self.num_inputs - 1) + ) + .log() + .item() + ) + self._reset_parameters() + + def _reset_parameters(self): + for i in range(len(self.linear)): + nn.init.eye_(self.linear[i].weight) + nn.init.constant_(self.linear[i].bias, 0.0) + + def forward(self, inputs: List[Tensor]) -> Tensor: + """Forward function. + Args: + inputs: + A list of Tensor, e.g. from various layers of a transformer. + All must be the same shape, of (*, num_channels) + Returns: + A Tensor of shape (*, num_channels). In test mode + this is just the final input. + """ + num_inputs = self.num_inputs + assert len(inputs) == num_inputs + if not self.training: + return inputs[-1] + + # Shape of weights: (*, num_inputs) + num_channels = inputs[0].shape[-1] + num_frames = inputs[0].numel() // num_channels + + mod_inputs = [] + for i in range(num_inputs - 1): + mod_inputs.append(self.linear[i](inputs[i])) + mod_inputs.append(inputs[num_inputs - 1]) + + ndim = inputs[0].ndim + # stacked_inputs: (num_frames, num_channels, num_inputs) + stacked_inputs = torch.stack(mod_inputs, dim=ndim).reshape( + (num_frames, num_channels, num_inputs) + ) + + # weights: (num_frames, num_inputs) + weights = self._get_random_weights( + inputs[0].dtype, inputs[0].device, num_frames + ) + + weights = weights.reshape(num_frames, num_inputs, 1) + # ans: (num_frames, num_channels, 1) + ans = torch.matmul(stacked_inputs, weights) + # ans: (*, num_channels) + ans = ans.reshape(*tuple(inputs[0].shape[:-1]), num_channels) + + if __name__ == "__main__": + # for testing only... + print("Weights = ", weights.reshape(num_frames, num_inputs)) + return ans + + def _get_random_weights( + self, dtype: torch.dtype, device: torch.device, num_frames: int + ) -> Tensor: + """Return a tensor of random weights, of shape + `(num_frames, self.num_inputs)`, + Args: + dtype: + The data-type desired for the answer, e.g. float, double. + device: + The device needed for the answer. + num_frames: + The number of sets of weights desired + Returns: + A tensor of shape (num_frames, self.num_inputs), such that + `ans.sum(dim=1)` is all ones. + """ + pure_prob = self.pure_prob + if pure_prob == 0.0: + return self._get_random_mixed_weights(dtype, device, num_frames) + elif pure_prob == 1.0: + return self._get_random_pure_weights(dtype, device, num_frames) + else: + p = self._get_random_pure_weights(dtype, device, num_frames) + m = self._get_random_mixed_weights(dtype, device, num_frames) + return torch.where( + torch.rand(num_frames, 1, device=device) < self.pure_prob, p, m + ) + + def _get_random_pure_weights( + self, dtype: torch.dtype, device: torch.device, num_frames: int + ): + """Return a tensor of random one-hot weights, of shape + `(num_frames, self.num_inputs)`, + Args: + dtype: + The data-type desired for the answer, e.g. float, double. + device: + The device needed for the answer. + num_frames: + The number of sets of weights desired. + Returns: + A one-hot tensor of shape `(num_frames, self.num_inputs)`, with + exactly one weight equal to 1.0 on each frame. 
+ """ + final_prob = self.final_weight + + # final contains self.num_inputs - 1 in all elements + final = torch.full((num_frames,), self.num_inputs - 1, device=device) + # nonfinal contains random integers in [0..num_inputs - 2], these are for non-final weights. + nonfinal = torch.randint( + self.num_inputs - 1, (num_frames,), device=device + ) + + indexes = torch.where( + torch.rand(num_frames, device=device) < final_prob, final, nonfinal + ) + ans = torch.nn.functional.one_hot( + indexes, num_classes=self.num_inputs + ).to(dtype=dtype) + return ans + + def _get_random_mixed_weights( + self, dtype: torch.dtype, device: torch.device, num_frames: int + ): + """Return a tensor of random one-hot weights, of shape + `(num_frames, self.num_inputs)`, + Args: + dtype: + The data-type desired for the answer, e.g. float, double. + device: + The device needed for the answer. + num_frames: + The number of sets of weights desired. + Returns: + A tensor of shape (num_frames, self.num_inputs), which elements + in [0..1] that sum to one over the second axis, i.e. + `ans.sum(dim=1)` is all ones. + """ + logprobs = ( + torch.randn(num_frames, self.num_inputs, dtype=dtype, device=device) + * self.stddev + ) + logprobs[:, -1] += self.final_log_weight + return logprobs.softmax(dim=1) + + +def _test_random_combine(final_weight: float, pure_prob: float, stddev: float): + print( + f"_test_random_combine: final_weight={final_weight}, pure_prob={pure_prob}, stddev={stddev}" + ) + num_inputs = 3 + num_channels = 50 + m = RandomCombine( + num_inputs=num_inputs, + num_channels=num_channels, + final_weight=final_weight, + pure_prob=pure_prob, + stddev=stddev, + ) + + x = [torch.ones(3, 4, num_channels) for _ in range(num_inputs)] + + y = m(x) + assert y.shape == x[0].shape + assert torch.allclose(y, x[0]) # .. since actually all ones. + + +def _test_random_combine_main(): + _test_random_combine(0.999, 0, 0.0) + _test_random_combine(0.5, 0, 0.0) + _test_random_combine(0.999, 0, 0.0) + _test_random_combine(0.5, 0, 0.3) + _test_random_combine(0.5, 1, 0.3) + _test_random_combine(0.5, 0.5, 0.3) + + +def _test_conformer_main(): feature_dim = 50 c = Conformer(num_features=feature_dim, d_model=128, nhead=4) batch_size = 5 seq_len = 20 + feature_dim = 50 + # Just make sure the forward pass runs. + + c = Conformer( + num_features=feature_dim, d_model=128, nhead=4 + ) + batch_size = 5 + seq_len = 20 # Just make sure the forward pass runs. 
f = c( torch.randn(batch_size, seq_len, feature_dim), torch.full((batch_size,), seq_len, dtype=torch.int64), warmup=0.5, ) + f # to remove flake8 warnings + + + + +if __name__ == "__main__": + _test_conformer_main() + _test_random_combine_main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless7/decode.py index d1af63aaa..c35429263 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/decode.py @@ -19,36 +19,36 @@ """ Usage: (1) greedy search -./pruned_transducer_stateless4/decode.py \ - --epoch 30 \ +./pruned_transducer_stateless7/decode.py \ + --epoch 28 \ --avg 15 \ - --exp-dir ./pruned_transducer_stateless4/exp \ + --exp-dir ./pruned_transducer_stateless7/exp \ --max-duration 600 \ --decoding-method greedy_search (2) beam search (not recommended) -./pruned_transducer_stateless4/decode.py \ - --epoch 30 \ +./pruned_transducer_stateless7/decode.py \ + --epoch 28 \ --avg 15 \ - --exp-dir ./pruned_transducer_stateless4/exp \ + --exp-dir ./pruned_transducer_stateless7/exp \ --max-duration 600 \ --decoding-method beam_search \ --beam-size 4 (3) modified beam search -./pruned_transducer_stateless4/decode.py \ - --epoch 30 \ +./pruned_transducer_stateless7/decode.py \ + --epoch 28 \ --avg 15 \ - --exp-dir ./pruned_transducer_stateless4/exp \ + --exp-dir ./pruned_transducer_stateless7/exp \ --max-duration 600 \ --decoding-method modified_beam_search \ --beam-size 4 (4) fast beam search -./pruned_transducer_stateless4/decode.py \ - --epoch 30 \ +./pruned_transducer_stateless7/decode.py \ + --epoch 28 \ --avg 15 \ - --exp-dir ./pruned_transducer_stateless4/exp \ + --exp-dir ./pruned_transducer_stateless7/exp \ --max-duration 600 \ --decoding-method fast_beam_search \ --beam 4 \ @@ -75,7 +75,7 @@ from beam_search import ( greedy_search_batch, modified_beam_search, ) -from train import get_params, get_transducer_model +from train import add_model_arguments, get_params, get_transducer_model from icefall.checkpoint import ( average_checkpoints, @@ -139,7 +139,7 @@ def get_parser(): parser.add_argument( "--exp-dir", type=str, - default="pruned_transducer_stateless4/exp", + default="pruned_transducer_stateless7/exp", help="The experiment dir", ) @@ -212,6 +212,8 @@ def get_parser(): Used only when --decoding_method is greedy_search""", ) + add_model_arguments(parser) + return parser @@ -302,7 +304,7 @@ def decode_one_batch( for i in range(batch_size): # fmt: off - encoder_out_i = encoder_out[i:i + 1, :encoder_out_lens[i]] + encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]] # fmt: on if params.decoding_method == "greedy_search": hyp = greedy_search( @@ -374,7 +376,7 @@ def decode_dataset( if params.decoding_method == "greedy_search": log_interval = 50 else: - log_interval = 10 + log_interval = 20 results = defaultdict(list) for batch_idx, batch in enumerate(dl): diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/test_model.py b/egs/librispeech/ASR/pruned_transducer_stateless7/test_model.py index b1832d0ec..9aad32014 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/test_model.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/test_model.py @@ -23,27 +23,42 @@ To run this file, do: python ./pruned_transducer_stateless4/test_model.py """ -import torch from train import get_params, get_transducer_model -def test_model(): +def test_model_1(): params = get_params() params.vocab_size = 500 params.blank_id = 0 params.context_size = 
2 - params.unk_id = 2 - + params.num_encoder_layers = 24 + params.dim_feedforward = 1536 # 384 * 4 + params.encoder_dim = 384 + model = get_transducer_model(params) + num_param = sum([p.numel() for p in model.parameters()]) + print(f"Number of model parameters: {num_param}") + + +# See Table 1 from https://arxiv.org/pdf/2005.08100.pdf +def test_model_M(): + params = get_params() + params.vocab_size = 500 + params.blank_id = 0 + params.context_size = 2 + params.num_encoder_layers = 18 + params.dim_feedforward = 1024 + params.encoder_dim = 256 + params.nhead = 4 + params.decoder_dim = 512 + params.joiner_dim = 512 model = get_transducer_model(params) - num_param = sum([p.numel() for p in model.parameters()]) print(f"Number of model parameters: {num_param}") - model.__class__.forward = torch.jit.ignore(model.__class__.forward) - torch.jit.script(model) def main(): - test_model() + # test_model_1() + test_model_M() if __name__ == "__main__": diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/train.py b/egs/librispeech/ASR/pruned_transducer_stateless7/train.py index 2e8d64971..b2f1cc792 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/train.py @@ -1,8 +1,8 @@ #!/usr/bin/env python3 -# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang, -# Wei Kang, -# Mingshuang Luo,) -# Zengwei Yao) +# Copyright 2021-2022 Xiaomi Corp. (authors: Fangjun Kuang, +# Wei Kang, +# Mingshuang Luo,) +# Zengwei Yao) # # See ../../../../LICENSE for clarification regarding multiple authors # @@ -22,22 +22,22 @@ Usage: export CUDA_VISIBLE_DEVICES="0,1,2,3" -./pruned_transducer_stateless4/train.py \ +./pruned_transducer_stateless7/train.py \ --world-size 4 \ --num-epochs 30 \ --start-epoch 1 \ - --exp-dir pruned_transducer_stateless2/exp \ + --exp-dir pruned_transducer_stateless7/exp \ --full-libri 1 \ --max-duration 300 # For mix precision training: -./pruned_transducer_stateless4/train.py \ +./pruned_transducer_stateless7/train.py \ --world-size 4 \ --num-epochs 30 \ --start-epoch 1 \ --use-fp16 1 \ - --exp-dir pruned_transducer_stateless2/exp \ + --exp-dir pruned_transducer_stateless7/exp \ --full-libri 1 \ --max-duration 550 @@ -88,6 +88,53 @@ LRSchedulerType = Union[ ] +def add_model_arguments(parser: argparse.ArgumentParser): + parser.add_argument( + "--num-encoder-layers", + type=int, + default=24, + help="Number of conformer encoder layers..", + ) + + parser.add_argument( + "--dim-feedforward", + type=int, + default=1536, + help="Feedforward dimension of the conformer encoder layer.", + ) + + parser.add_argument( + "--nhead", + type=int, + default=8, + help="Number of attention heads in the conformer encoder layer.", + ) + + parser.add_argument( + "--encoder-dim", + type=int, + default=384, + help="Attention dimension in the conformer encoder layer.", + ) + + parser.add_argument( + "--decoder-dim", + type=int, + default=512, + help="Embedding dimension in the decoder model.", + ) + + parser.add_argument( + "--joiner-dim", + type=int, + default=512, + help="""Dimension used in the joiner model. + Outputs from the encoder and decoder model are projected + to this dimension before adding. + """, + ) + + def get_parser(): parser = argparse.ArgumentParser( formatter_class=argparse.ArgumentDefaultsHelpFormatter @@ -143,7 +190,7 @@ def get_parser(): parser.add_argument( "--exp-dir", type=str, - default="pruned_transducer_stateless2/exp", + default="pruned_transducer_stateless7/exp", help="""The experiment dir. 
It specifies the directory where all training related files, e.g., checkpoints, log, etc, are saved @@ -161,16 +208,16 @@ def get_parser(): "--initial-lr", type=float, default=0.003, - help="""The initial learning rate. This value should not need to be - changed.""", + help="The initial learning rate. This value should not need " + "to be changed.", ) parser.add_argument( "--lr-batches", type=float, default=5000, - help="""Number of steps that affects how rapidly the learning rate decreases. - We suggest not to change this.""", + help="""Number of steps that affects how rapidly the learning rate + decreases. We suggest not to change this.""", ) parser.add_argument( @@ -240,7 +287,7 @@ def get_parser(): parser.add_argument( "--save-every-n", type=int, - default=8000, + default=4000, help="""Save checkpoint after processing this number of batches" periodically. We save checkpoint to exp-dir/ whenever params.batch_idx_train % save_every_n == 0. The checkpoint filename @@ -253,7 +300,7 @@ def get_parser(): parser.add_argument( "--keep-last-k", type=int, - default=20, + default=30, help="""Only keep this number of checkpoints on disk. For instance, if it is 3, there are only 3 checkpoints in the exp-dir with filenames `checkpoint-xxx.pt`. @@ -281,6 +328,8 @@ def get_parser(): help="Whether to use half precision training.", ) + add_model_arguments(parser) + return parser @@ -341,14 +390,6 @@ def get_params() -> AttributeDict: # parameters for conformer "feature_dim": 80, "subsampling_factor": 4, - "encoder_dim": 512, - "nhead": 8, - "dim_feedforward": 2048, - "num_encoder_layers": 12, - # parameters for decoder - "decoder_dim": 512, - # parameters for joiner - "joiner_dim": 512, # parameters for Noam "model_warm_step": 3000, # arg given to model, not for lrate "env_info": get_env_info(), @@ -704,27 +745,31 @@ def train_one_epoch( params.batch_idx_train += 1 batch_size = len(batch["supervisions"]["text"]) - with torch.cuda.amp.autocast(enabled=params.use_fp16): - loss, loss_info = compute_loss( - params=params, - model=model, - sp=sp, - batch=batch, - is_training=True, - warmup=(params.batch_idx_train / params.model_warm_step), - ) - # summary stats - tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info + try: + with torch.cuda.amp.autocast(enabled=params.use_fp16): + loss, loss_info = compute_loss( + params=params, + model=model, + sp=sp, + batch=batch, + is_training=True, + warmup=(params.batch_idx_train / params.model_warm_step), + ) + # summary stats + tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info - # NOTE: We use reduction==sum and loss is computed over utterances - # in the batch and there is no normalization to it so far. - scaler.scale(loss).backward() - scheduler.step_batch(params.batch_idx_train) - scaler.step(optimizer) - scaler.update() - optimizer.zero_grad() + # NOTE: We use reduction==sum and loss is computed over utterances + # in the batch and there is no normalization to it so far. 
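+            # When --use-fp16 is false, autocast and the GradScaler are
+            # no-ops, so this same code path also handles fp32 training.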
+ scaler.scale(loss).backward() + scheduler.step_batch(params.batch_idx_train) + scaler.step(optimizer) + scaler.update() + optimizer.zero_grad() + except: # noqa + display_and_save_batch(batch, params=params, sp=sp) + raise - if params.print_diagnostics and batch_idx == 30: + if params.print_diagnostics and batch_idx == 5: return if ( @@ -888,7 +933,10 @@ def run(rank, world_size, args): scheduler.load_state_dict(checkpoints["scheduler"]) if params.print_diagnostics: - diagnostic = diagnostics.attach_diagnostics(model) + opts = diagnostics.TensorDiagnosticOptions( + 2 ** 22 + ) # allow 4 megabytes per sub-module + diagnostic = diagnostics.attach_diagnostics(model, opts) librispeech = LibriSpeechAsrDataModule(args) @@ -986,6 +1034,38 @@ def run(rank, world_size, args): cleanup_dist() +def display_and_save_batch( + batch: dict, + params: AttributeDict, + sp: spm.SentencePieceProcessor, +) -> None: + """Display the batch statistics and save the batch into disk. + + Args: + batch: + A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` + for the content in it. + params: + Parameters for training. See :func:`get_params`. + sp: + The BPE model. + """ + from lhotse.utils import uuid4 + + filename = f"{params.exp_dir}/batch-{uuid4()}.pt" + logging.info(f"Saving batch to {filename}") + torch.save(batch, filename) + + supervisions = batch["supervisions"] + features = batch["inputs"] + + logging.info(f"features shape: {features.shape}") + + y = sp.encode(supervisions["text"], out_type=int) + num_tokens = sum(len(i) for i in y) + logging.info(f"num tokens: {num_tokens}") + + def scan_pessimistic_batches_for_oom( model: Union[nn.Module, DDP], train_dl: torch.utils.data.DataLoader, @@ -1017,7 +1097,7 @@ def scan_pessimistic_batches_for_oom( loss.backward() optimizer.step() optimizer.zero_grad() - except RuntimeError as e: + except Exception as e: if "CUDA out of memory" in str(e): logging.error( "Your GPU ran out of memory with the current " @@ -1026,6 +1106,7 @@ def scan_pessimistic_batches_for_oom( f"Failing criterion: {criterion} " f"(={crit_values[criterion]}) ..." ) + display_and_save_batch(batch, params=params, sp=sp) raise
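
Below is a minimal usage sketch (not part of the patch) for the new RandomCombine module added to conformer.py. It assumes it is run from egs/librispeech/ASR/pruned_transducer_stateless7 with the icefall environment available; shapes and hyper-parameters here are illustrative only.

    import torch
    from conformer import RandomCombine

    # Three dummy "layer outputs" of shape (T, N, C), like the aux-layer
    # outputs that ConformerEncoder.forward() collects and combines.
    xs = [torch.randn(20, 5, 8) for _ in range(3)]

    m = RandomCombine(
        num_inputs=3,
        num_channels=8,
        final_weight=0.5,
        pure_prob=0.333,
        stddev=2.0,
    )

    m.train()
    y = m(xs)  # a random per-frame mixture of the three inputs
    assert y.shape == xs[0].shape

    m.eval()
    assert torch.equal(m(xs), xs[-1])  # at test time: just the last input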