From 8541dc73f9c0e72b70055e549315331a1d697f29 Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Tue, 28 Dec 2021 20:11:01 +0800 Subject: [PATCH 1/3] WIP: Use optimized_transducer to compute transducer loss. --- .../ASR/transducer_stateless/joiner.py | 40 ++++++++++++++----- .../ASR/transducer_stateless/model.py | 24 +++++------ 2 files changed, 41 insertions(+), 23 deletions(-) diff --git a/egs/librispeech/ASR/transducer_stateless/joiner.py b/egs/librispeech/ASR/transducer_stateless/joiner.py index 2ef3f1de6..6ee22deec 100644 --- a/egs/librispeech/ASR/transducer_stateless/joiner.py +++ b/egs/librispeech/ASR/transducer_stateless/joiner.py @@ -22,32 +22,50 @@ class Joiner(nn.Module): def __init__(self, input_dim: int, output_dim: int): super().__init__() + self.input_dim = input_dim + self.output_dim = output_dim self.output_linear = nn.Linear(input_dim, output_dim) def forward( - self, encoder_out: torch.Tensor, decoder_out: torch.Tensor + self, + encoder_out: torch.Tensor, + decoder_out: torch.Tensor, + encoder_out_len: torch.Tensor, + decoder_out_len: torch.Tensor, ) -> torch.Tensor: """ Args: encoder_out: - Output from the encoder. Its shape is (N, T, C). + Output from the encoder. Its shape is (N, T, self.input_dim). decoder_out: - Output from the decoder. Its shape is (N, U, C). + Output from the decoder. Its shape is (N, U, self.input_dim). Returns: - Return a tensor of shape (N, T, U, C). + Return a tensor of shape (sum_all_TU, self.output_dim). """ assert encoder_out.ndim == decoder_out.ndim == 3 assert encoder_out.size(0) == decoder_out.size(0) - assert encoder_out.size(2) == decoder_out.size(2) + assert encoder_out.size(2) == self.input_dim + assert decoder_out.size(2) == self.input_dim - encoder_out = encoder_out.unsqueeze(2) - # Now encoder_out is (N, T, 1, C) + N = encoder_out.size(0) - decoder_out = decoder_out.unsqueeze(1) - # Now decoder_out is (N, 1, U, C) + encoder_out_list = [ + encoder_out[i, : encoder_out_len[i], :] for i in range(N) + ] - logit = encoder_out + decoder_out - logit = torch.tanh(logit) + decoder_out_list = [ + decoder_out[i, : decoder_out_len[i], :] for i in range(N) + ] + + x = [ + e.unsqueeze(1) + d.unsqueeze(0) + for e, d in zip(encoder_out_list, decoder_out_list) + ] + + x = [p.reshape(-1, self.input_dim) for p in x] + x = torch.cat(x) + + logit = torch.tanh(x) output = self.output_linear(logit) diff --git a/egs/librispeech/ASR/transducer_stateless/model.py b/egs/librispeech/ASR/transducer_stateless/model.py index 2f0f9a183..98a6f0f37 100644 --- a/egs/librispeech/ASR/transducer_stateless/model.py +++ b/egs/librispeech/ASR/transducer_stateless/model.py @@ -14,15 +14,9 @@ # See the License for the specific language governing permissions and # limitations under the License. -""" -Note we use `rnnt_loss` from torchaudio, which exists only in -torchaudio >= v0.10.0. It also means you have to use torch >= v1.10.0 -""" import k2 import torch import torch.nn as nn -import torchaudio -import torchaudio.functional from encoder_interface import EncoderInterface from icefall.utils import add_sos @@ -102,18 +96,24 @@ class Transducer(nn.Module): decoder_out = self.decoder(sos_y_padded) - logits = self.joiner(encoder_out, decoder_out) + # +1 here since a blank is prepended to each utterance. + logits = self.joiner( + encoder_out=encoder_out, + decoder_out=decoder_out, + encoder_out_len=x_lens, + decoder_out_len=y_lens + 1, + ) # rnnt_loss requires 0 padded targets # Note: y does not start with SOS y_padded = y.pad(mode="constant", padding_value=0) - assert hasattr(torchaudio.functional, "rnnt_loss"), ( - f"Current torchaudio version: {torchaudio.__version__}\n" - "Please install a version >= 0.10.0" - ) + # We don't put this `import` at the beginning of the file + # as it is required only in the training, not during the + # reference stage + import optimized_transducer - loss = torchaudio.functional.rnnt_loss( + loss = optimized_transducer.transducer_loss( logits=logits, targets=y_padded, logit_lengths=x_lens, From 7828c6ff7325e37f4afab94a65f6d049f3d3e05d Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Thu, 30 Dec 2021 00:15:23 +0800 Subject: [PATCH 2/3] Minor fixes. --- egs/librispeech/ASR/transducer_stateless/joiner.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/egs/librispeech/ASR/transducer_stateless/joiner.py b/egs/librispeech/ASR/transducer_stateless/joiner.py index 6ee22deec..9fd9da4f1 100644 --- a/egs/librispeech/ASR/transducer_stateless/joiner.py +++ b/egs/librispeech/ASR/transducer_stateless/joiner.py @@ -65,8 +65,8 @@ class Joiner(nn.Module): x = [p.reshape(-1, self.input_dim) for p in x] x = torch.cat(x) - logit = torch.tanh(x) + activations = torch.tanh(x) - output = self.output_linear(logit) + logits = self.output_linear(activations) - return output + return logits From b49510e2bf7064f4f60650e6787288db1bad2941 Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Fri, 31 Dec 2021 15:52:33 +0800 Subject: [PATCH 3/3] Add label smoothing for transducer loss. --- .../ASR/transducer_stateless/model.py | 67 ++++++++++++++++++- .../ASR/transducer_stateless/train.py | 14 +++- 2 files changed, 79 insertions(+), 2 deletions(-) diff --git a/egs/librispeech/ASR/transducer_stateless/model.py b/egs/librispeech/ASR/transducer_stateless/model.py index 98a6f0f37..8cd406df0 100644 --- a/egs/librispeech/ASR/transducer_stateless/model.py +++ b/egs/librispeech/ASR/transducer_stateless/model.py @@ -14,6 +14,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +import math + import k2 import torch import torch.nn as nn @@ -22,6 +24,61 @@ from encoder_interface import EncoderInterface from icefall.utils import add_sos +def reverse_label_smoothing( + logprobs: torch.Tensor, alpha: float +) -> torch.Tensor: + """ + This function is written by Dan. + + Modifies `logprobs` in such a way that if you compute a data probability + using `logprobs`, it will be equivalent to a label-smoothed data probability + with the supplied label-smoothing constant alpha (e.g. alpha=0.1). + This allows us to use `logprobs` in things like RNN-T and CTC and + get a kind of label-smoothed version of those sequence objectives. + + Label smoothing means that if the reference label is i, we convert it + into a distribution with weight (1-alpha) on i, and alpha distributed + equally to all labels (including i itself). + + Note: the output logprobs can be interpreted as cross-entropies, meaning + we correct for the entropy of the smoothed distribution. + + Args: + logprobs: + A Tensor of shape (*, num_classes), containing logprobs that sum + to one: e.g. the output of log_softmax. + alpha: + A constant that defines the extent of label smoothing, e.g. 0.1. + + Returns: + modified_logprobs, a Tensor of shape (*, num_classes), containing + "fake" logprobs that will give you label-smoothed probabilities. + """ + assert alpha >= 0.0 and alpha < 1 + if alpha == 0.0: + return logprobs + num_classes = logprobs.shape[-1] + + # We correct for the entropy of the label-smoothed target distribution, so + # the resulting logprobs can be thought of as cross-entropies, which are + # more interpretable. + # + # The expression for entropy below is not quite correct -- it treats + # the target label and the smoothed version of the target label as being + # separate classes -- but this can be thought of as an adjustment + # for the way we compute the likelihood below, which also treats the + # target label and its smoothed version as being separate. + target_entropy = -( + (1 - alpha) * math.log(1 - alpha) + + alpha * math.log(alpha / num_classes) + ) + sum_logprob = logprobs.sum(dim=-1, keepdim=True) + + return ( + logprobs * (1 - alpha) + sum_logprob * (alpha / num_classes) + ) + target_entropy + + class Transducer(nn.Module): """It implements https://arxiv.org/pdf/1211.3711.pdf "Sequence Transduction with Recurrent Neural Networks" @@ -62,6 +119,7 @@ class Transducer(nn.Module): x: torch.Tensor, x_lens: torch.Tensor, y: k2.RaggedTensor, + label_smoothing_factor: float, ) -> torch.Tensor: """ Args: @@ -73,6 +131,8 @@ class Transducer(nn.Module): y: A ragged tensor with 2 axes [utt][label]. It contains labels of each utterance. + label_smoothing_factor: + The factor for label smoothing. Should be in the range [0, 1). Returns: Return the transducer loss. """ @@ -103,6 +163,10 @@ class Transducer(nn.Module): encoder_out_len=x_lens, decoder_out_len=y_lens + 1, ) + # logits is of shape (sum_all_TU, vocab_size) + + log_probs = logits.log_softmax(dim=-1) + log_probs = reverse_label_smoothing(log_probs, label_smoothing_factor) # rnnt_loss requires 0 padded targets # Note: y does not start with SOS @@ -114,12 +178,13 @@ class Transducer(nn.Module): import optimized_transducer loss = optimized_transducer.transducer_loss( - logits=logits, + logits=log_probs, targets=y_padded, logit_lengths=x_lens, target_lengths=y_lens, blank=blank_id, reduction="sum", + from_log_softmax=True, ) return loss diff --git a/egs/librispeech/ASR/transducer_stateless/train.py b/egs/librispeech/ASR/transducer_stateless/train.py index 694ebf1d5..41f8311ec 100755 --- a/egs/librispeech/ASR/transducer_stateless/train.py +++ b/egs/librispeech/ASR/transducer_stateless/train.py @@ -138,6 +138,13 @@ def get_parser(): "2 means tri-gram", ) + parser.add_argument( + "--label-smoothing-factor", + type=float, + default=0.1, + help="The factor for label smoothing", + ) + return parser @@ -383,7 +390,12 @@ def compute_loss( y = k2.RaggedTensor(y).to(device) with torch.set_grad_enabled(is_training): - loss = model(x=feature, x_lens=feature_lens, y=y) + loss = model( + x=feature, + x_lens=feature_lens, + y=y, + label_smoothing_factor=params.label_smoothing_factor, + ) assert loss.requires_grad == is_training