diff --git a/egs/librispeech/ASR/transducer_stateless/beam_search.py b/egs/librispeech/ASR/transducer_stateless/beam_search.py index 9ed9b2ad1..f347f552f 100644 --- a/egs/librispeech/ASR/transducer_stateless/beam_search.py +++ b/egs/librispeech/ASR/transducer_stateless/beam_search.py @@ -73,9 +73,9 @@ def greedy_search( continue # fmt: off - current_encoder_out = encoder_out[:, t:t+1, :] + current_encoder_out = encoder_out[:, t:t+1, :].unsqueeze(2) # fmt: on - logits = model.joiner(current_encoder_out, decoder_out) + logits = model.joiner(current_encoder_out, decoder_out.unsqueeze(1)) # logits is (1, 1, 1, vocab_size) y = logits.argmax().item() diff --git a/egs/librispeech/ASR/transducer_stateless/joiner.py b/egs/librispeech/ASR/transducer_stateless/joiner.py index 2ef3f1de6..89fab9e4c 100644 --- a/egs/librispeech/ASR/transducer_stateless/joiner.py +++ b/egs/librispeech/ASR/transducer_stateless/joiner.py @@ -30,21 +30,14 @@ class Joiner(nn.Module): """ Args: encoder_out: - Output from the encoder. Its shape is (N, T, C). + Output from the encoder. Its shape is (N, T, s_range, C). decoder_out: - Output from the decoder. Its shape is (N, U, C). + Output from the decoder. Its shape is (N, T, s_range, C). Returns: - Return a tensor of shape (N, T, U, C). + Return a tensor of shape (N, T, s_range, C). """ - assert encoder_out.ndim == decoder_out.ndim == 3 - assert encoder_out.size(0) == decoder_out.size(0) - assert encoder_out.size(2) == decoder_out.size(2) - - encoder_out = encoder_out.unsqueeze(2) - # Now encoder_out is (N, T, 1, C) - - decoder_out = decoder_out.unsqueeze(1) - # Now decoder_out is (N, 1, U, C) + assert encoder_out.ndim == decoder_out.ndim == 4 + assert encoder_out.shape == decoder_out.shape logit = encoder_out + decoder_out logit = torch.tanh(logit) diff --git a/egs/librispeech/ASR/transducer_stateless/model.py b/egs/librispeech/ASR/transducer_stateless/model.py index 2f0f9a183..837d76ddc 100644 --- a/egs/librispeech/ASR/transducer_stateless/model.py +++ b/egs/librispeech/ASR/transducer_stateless/model.py @@ -1,4 +1,4 @@ -# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) +# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang, Wei Kang) # # See ../../../../LICENSE for clarification regarding multiple authors # @@ -14,15 +14,10 @@ # See the License for the specific language governing permissions and # limitations under the License. -""" -Note we use `rnnt_loss` from torchaudio, which exists only in -torchaudio >= v0.10.0. It also means you have to use torch >= v1.10.0 -""" + import k2 import torch import torch.nn as nn -import torchaudio -import torchaudio.functional from encoder_interface import EncoderInterface from icefall.utils import add_sos @@ -38,6 +33,7 @@ class Transducer(nn.Module): encoder: EncoderInterface, decoder: nn.Module, joiner: nn.Module, + prune_range: int = 3, ): """ Args: @@ -62,6 +58,7 @@ class Transducer(nn.Module): self.encoder = encoder self.decoder = decoder self.joiner = joiner + self.prune_range = prune_range def forward( self, @@ -102,24 +99,32 @@ class Transducer(nn.Module): decoder_out = self.decoder(sos_y_padded) - logits = self.joiner(encoder_out, decoder_out) - - # rnnt_loss requires 0 padded targets # Note: y does not start with SOS y_padded = y.pad(mode="constant", padding_value=0) - assert hasattr(torchaudio.functional, "rnnt_loss"), ( - f"Current torchaudio version: {torchaudio.__version__}\n" - "Please install a version >= 0.10.0" + y_padded = y_padded.to(torch.int64) + boundary = torch.zeros( + (x.size(0), 4), dtype=torch.int64, device=x.device + ) + boundary[:, 2] = y_lens + boundary[:, 3] = x_lens + + simple_loss, (px_grad, py_grad) = k2.rnnt_loss_simple( + decoder_out, encoder_out, y_padded, blank_id, boundary, True ) - loss = torchaudio.functional.rnnt_loss( - logits=logits, - targets=y_padded, - logit_lengths=x_lens, - target_lengths=y_lens, - blank=blank_id, - reduction="sum", + ranges = k2.get_rnnt_prune_ranges( + px_grad, py_grad, boundary, self.prune_range ) - return loss + am_pruning, lm_pruning = k2.do_rnnt_pruning( + encoder_out, decoder_out, ranges + ) + + logits = self.joiner(am_pruning, lm_pruning) + + pruning_loss = k2.rnnt_loss_pruned( + logits, y_padded, ranges, blank_id, boundary + ) + + return (-torch.sum(simple_loss), -torch.sum(pruning_loss)) diff --git a/egs/librispeech/ASR/transducer_stateless/train.py b/egs/librispeech/ASR/transducer_stateless/train.py index 694ebf1d5..3b92a5ee4 100755 --- a/egs/librispeech/ASR/transducer_stateless/train.py +++ b/egs/librispeech/ASR/transducer_stateless/train.py @@ -60,7 +60,15 @@ from icefall.checkpoint import load_checkpoint from icefall.checkpoint import save_checkpoint as save_checkpoint_impl from icefall.dist import cleanup_dist, setup_dist from icefall.env import get_env_info -from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool +from icefall.utils import ( + AttributeDict, + MetricsTracker, + measure_gradient_norms, + measure_weight_norms, + optim_step_and_measure_param_change, + setup_logger, + str2bool, +) def get_parser(): @@ -138,6 +146,14 @@ def get_parser(): "2 means tri-gram", ) + parser.add_argument( + "--prune-range", + type=int, + default=3, + help="The prune range for rnnt loss, it means how many symbols(context)" + "we are using to compute the loss", + ) + return parser @@ -195,6 +211,7 @@ def get_params() -> AttributeDict: "log_interval": 50, "reset_interval": 200, "valid_interval": 3000, # For the 100h subset, use 800 + "log_diagnostics": False, # parameters for conformer "feature_dim": 80, "encoder_out_dim": 512, @@ -383,7 +400,8 @@ def compute_loss( y = k2.RaggedTensor(y).to(device) with torch.set_grad_enabled(is_training): - loss = model(x=feature, x_lens=feature_lens, y=y) + simple_loss, pruned_loss = model(x=feature, x_lens=feature_lens, y=y) + loss = simple_loss + pruned_loss assert loss.requires_grad == is_training @@ -392,6 +410,8 @@ def compute_loss( # Note: We use reduction=sum while computing the loss. info["loss"] = loss.detach().cpu().item() + info["simple_loss"] = simple_loss.detach().cpu().item() + info["pruned_loss"] = pruned_loss.detach().cpu().item() return loss, info @@ -466,6 +486,45 @@ def train_one_epoch( tot_loss = MetricsTracker() + def maybe_log_gradients(tag: str): + if ( + params.log_diagnostics + and tb_writer is not None + and params.batch_idx_train % (params.log_interval * 5) == 0 + ): + tb_writer.add_scalars( + tag, + measure_gradient_norms(model, norm="l2"), + global_step=params.batch_idx_train, + ) + + def maybe_log_weights(tag: str): + if ( + params.log_diagnostics + and tb_writer is not None + and params.batch_idx_train % (params.log_interval * 5) == 0 + ): + tb_writer.add_scalars( + tag, + measure_weight_norms(model, norm="l2"), + global_step=params.batch_idx_train, + ) + + def maybe_log_param_relative_changes(): + if ( + params.log_diagnostics + and tb_writer is not None + and params.batch_idx_train % (params.log_interval * 5) == 0 + ): + deltas = optim_step_and_measure_param_change(model, optimizer) + tb_writer.add_scalars( + "train/relative_param_change_per_minibatch", + deltas, + global_step=params.batch_idx_train, + ) + else: + optimizer.step() + for batch_idx, batch in enumerate(train_dl): params.batch_idx_train += 1 batch_size = len(batch["supervisions"]["text"]) @@ -483,10 +542,13 @@ def train_one_epoch( # NOTE: We use reduction==sum and loss is computed over utterances # in the batch and there is no normalization to it so far. - optimizer.zero_grad() loss.backward() - clip_grad_norm_(model.parameters(), 5.0, 2.0) - optimizer.step() + + maybe_log_weights("train/param_norms") + maybe_log_gradients("train/grad_norms") + maybe_log_param_relative_changes() + + optimizer.zero_grad() if batch_idx % params.log_interval == 0: logging.info( diff --git a/icefall/utils.py b/icefall/utils.py index 7237c8d62..fa8f4d334 100644 --- a/icefall/utils.py +++ b/icefall/utils.py @@ -690,3 +690,94 @@ def make_pad_mask(lengths: torch.Tensor) -> torch.Tensor: expaned_lengths = torch.arange(max_len).expand(n, max_len).to(lengths) return expaned_lengths >= lengths.unsqueeze(1) + + +def l1_norm(x): + return torch.sum(torch.abs(x)) + + +def l2_norm(x): + return torch.sum(torch.pow(x, 2)) + + +def linf_norm(x): + return torch.max(torch.abs(x)) + + +def measure_weight_norms( + model: nn.Module, norm: str = "l2" +) -> Dict[str, float]: + """ + Compute the norms of the model's parameters. + + :param model: a torch.nn.Module instance + :param norm: how to compute the norm. Available values: 'l1', 'l2', 'linf' + :return: a dict mapping from parameter's name to its norm. + """ + with torch.no_grad(): + norms = {} + for name, param in model.named_parameters(): + if norm == "l1": + val = l1_norm(param) + elif norm == "l2": + val = l2_norm(param) + elif norm == "linf": + val = linf_norm(param) + else: + raise ValueError(f"Unknown norm type: {norm}") + norms[name] = val.item() + return norms + + +def measure_gradient_norms( + model: nn.Module, norm: str = "l1" +) -> Dict[str, float]: + """ + Compute the norms of the gradients for each of model's parameters. + + :param model: a torch.nn.Module instance + :param norm: how to compute the norm. Available values: 'l1', 'l2', 'linf' + :return: a dict mapping from parameter's name to its gradient's norm. + """ + with torch.no_grad(): + norms = {} + for name, param in model.named_parameters(): + if norm == "l1": + val = l1_norm(param.grad) + elif norm == "l2": + val = l2_norm(param.grad) + elif norm == "linf": + val = linf_norm(param.grad) + else: + raise ValueError(f"Unknown norm type: {norm}") + norms[name] = val.item() + return norms + + +def optim_step_and_measure_param_change( + model: nn.Module, + optimizer: torch.optim.Optimizer, + scaler: Optional[GradScaler] = None, +) -> Dict[str, float]: + """ + Perform model weight update and measure the "relative change in parameters per minibatch." + It is understood as a ratio between the L2 norm of the difference between original and updates parameters, + and the L2 norm of the original parameter. It is given by the formula: + + .. math:: + \begin{aligned} + \delta = \frac{\Vert\theta - \theta_{new}\Vert^2}{\Vert\theta\Vert^2} + \end{aligned} + """ + param_copy = {n: p.detach().clone() for n, p in model.named_parameters()} + if scaler: + scaler.step(optimizer) + else: + optimizer.step() + relative_change = {} + with torch.no_grad(): + for n, p_new in model.named_parameters(): + p_orig = param_copy[n] + delta = l2_norm(p_orig - p_new) / l2_norm(p_orig) + relative_change[n] = delta.item() + return relative_change