From b49510e2bf7064f4f60650e6787288db1bad2941 Mon Sep 17 00:00:00 2001
From: Fangjun Kuang
Date: Fri, 31 Dec 2021 15:52:33 +0800
Subject: [PATCH] Add label smoothing for transducer loss.

---
 .../ASR/transducer_stateless/model.py | 67 ++++++++++++++++++-
 .../ASR/transducer_stateless/train.py | 14 +++-
 2 files changed, 79 insertions(+), 2 deletions(-)

diff --git a/egs/librispeech/ASR/transducer_stateless/model.py b/egs/librispeech/ASR/transducer_stateless/model.py
index 98a6f0f37..8cd406df0 100644
--- a/egs/librispeech/ASR/transducer_stateless/model.py
+++ b/egs/librispeech/ASR/transducer_stateless/model.py
@@ -14,6 +14,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import math
+
 import k2
 import torch
 import torch.nn as nn
@@ -22,6 +24,61 @@ from encoder_interface import EncoderInterface
 
 from icefall.utils import add_sos
 
+
+def reverse_label_smoothing(
+    logprobs: torch.Tensor, alpha: float
+) -> torch.Tensor:
+    """
+    This function is written by Dan.
+
+    Modifies `logprobs` in such a way that if you compute a data probability
+    using `logprobs`, it will be equivalent to a label-smoothed data probability
+    with the supplied label-smoothing constant alpha (e.g. alpha=0.1).
+    This allows us to use `logprobs` in things like RNN-T and CTC and
+    get a kind of label-smoothed version of those sequence objectives.
+
+    Label smoothing means that if the reference label is i, we convert it
+    into a distribution with weight (1-alpha) on i, and alpha distributed
+    equally to all labels (including i itself).
+
+    Note: the output logprobs can be interpreted as cross-entropies, meaning
+    we correct for the entropy of the smoothed distribution.
+
+    Args:
+      logprobs:
+        A Tensor of shape (*, num_classes), containing normalized logprobs
+        (exp(logprobs) sums to one over classes), e.g. the output of log_softmax.
+      alpha:
+        A constant that defines the extent of label smoothing, e.g. 0.1.
+
+    Returns:
+      modified_logprobs, a Tensor of shape (*, num_classes), containing
+      "fake" logprobs that will give you label-smoothed probabilities.
+    """
+    assert alpha >= 0.0 and alpha < 1
+    if alpha == 0.0:
+        return logprobs
+    num_classes = logprobs.shape[-1]
+
+    # We correct for the entropy of the label-smoothed target distribution, so
+    # the resulting logprobs can be thought of as cross-entropies, which are
+    # more interpretable.
+    #
+    # The expression for entropy below is not quite correct -- it treats
+    # the target label and the smoothed version of the target label as being
+    # separate classes -- but this can be thought of as an adjustment
+    # for the way we compute the likelihood below, which also treats the
+    # target label and its smoothed version as being separate.
+    target_entropy = -(
+        (1 - alpha) * math.log(1 - alpha)
+        + alpha * math.log(alpha / num_classes)
+    )
+    sum_logprob = logprobs.sum(dim=-1, keepdim=True)
+
+    return (
+        logprobs * (1 - alpha) + sum_logprob * (alpha / num_classes)
+    ) + target_entropy
+
+
 class Transducer(nn.Module):
     """It implements https://arxiv.org/pdf/1211.3711.pdf
     "Sequence Transduction with Recurrent Neural Networks"
@@ -62,6 +119,7 @@ class Transducer(nn.Module):
         x: torch.Tensor,
         x_lens: torch.Tensor,
         y: k2.RaggedTensor,
+        label_smoothing_factor: float,
     ) -> torch.Tensor:
         """
         Args:
@@ -73,6 +131,8 @@ class Transducer(nn.Module):
           y:
             A ragged tensor with 2 axes [utt][label]. It contains labels of
             each utterance.
+          label_smoothing_factor:
+            The factor for label smoothing. Should be in the range [0, 1).
         Returns:
           Return the transducer loss.
         """
@@ -103,6 +163,10 @@ class Transducer(nn.Module):
             encoder_out_len=x_lens,
             decoder_out_len=y_lens + 1,
         )
+        # logits is of shape (sum_all_TU, vocab_size)
+
+        log_probs = logits.log_softmax(dim=-1)
+        log_probs = reverse_label_smoothing(log_probs, label_smoothing_factor)
 
         # rnnt_loss requires 0 padded targets
         # Note: y does not start with SOS
@@ -114,12 +178,13 @@ class Transducer(nn.Module):
         import optimized_transducer
 
         loss = optimized_transducer.transducer_loss(
-            logits=logits,
+            logits=log_probs,
             targets=y_padded,
             logit_lengths=x_lens,
             target_lengths=y_lens,
             blank=blank_id,
             reduction="sum",
+            from_log_softmax=True,
        )
 
         return loss
diff --git a/egs/librispeech/ASR/transducer_stateless/train.py b/egs/librispeech/ASR/transducer_stateless/train.py
index 694ebf1d5..41f8311ec 100755
--- a/egs/librispeech/ASR/transducer_stateless/train.py
+++ b/egs/librispeech/ASR/transducer_stateless/train.py
@@ -138,6 +138,13 @@ def get_parser():
         "2 means tri-gram",
     )
 
+    parser.add_argument(
+        "--label-smoothing-factor",
+        type=float,
+        default=0.1,
+        help="The factor for label smoothing",
+    )
+
     return parser
 
 
@@ -383,7 +390,12 @@ def compute_loss(
     y = k2.RaggedTensor(y).to(device)
 
     with torch.set_grad_enabled(is_training):
-        loss = model(x=feature, x_lens=feature_lens, y=y)
+        loss = model(
+            x=feature,
+            x_lens=feature_lens,
+            y=y,
+            label_smoothing_factor=params.label_smoothing_factor,
+        )
 
         assert loss.requires_grad == is_training