diff --git a/egs/librispeech/ASR/transducer_stateless/joiner.py b/egs/librispeech/ASR/transducer_stateless/joiner.py
index 2ef3f1de6..9fd9da4f1 100644
--- a/egs/librispeech/ASR/transducer_stateless/joiner.py
+++ b/egs/librispeech/ASR/transducer_stateless/joiner.py
@@ -22,33 +22,59 @@ class Joiner(nn.Module):
     def __init__(self, input_dim: int, output_dim: int):
         super().__init__()
 
+        self.input_dim = input_dim
+        self.output_dim = output_dim
         self.output_linear = nn.Linear(input_dim, output_dim)
 
     def forward(
-        self, encoder_out: torch.Tensor, decoder_out: torch.Tensor
+        self,
+        encoder_out: torch.Tensor,
+        decoder_out: torch.Tensor,
+        encoder_out_len: torch.Tensor,
+        decoder_out_len: torch.Tensor,
     ) -> torch.Tensor:
         """
         Args:
           encoder_out:
-            Output from the encoder. Its shape is (N, T, C).
+            Output from the encoder. Its shape is (N, T, self.input_dim).
           decoder_out:
-            Output from the decoder. Its shape is (N, U, C).
+            Output from the decoder. Its shape is (N, U, self.input_dim).
+          encoder_out_len:
+            A 1-D tensor of shape (N,). It contains the number of valid
+            frames in encoder_out before padding.
+          decoder_out_len:
+            A 1-D tensor of shape (N,). It contains the number of valid
+            tokens in decoder_out before padding.
         Returns:
-          Return a tensor of shape (N, T, U, C).
+          Return a tensor of shape (sum_all_TU, self.output_dim), where
+          sum_all_TU is the sum over the batch of T_i * U_i, with T_i and
+          U_i given by encoder_out_len[i] and decoder_out_len[i].
         """
         assert encoder_out.ndim == decoder_out.ndim == 3
         assert encoder_out.size(0) == decoder_out.size(0)
-        assert encoder_out.size(2) == decoder_out.size(2)
+        assert encoder_out.size(2) == self.input_dim
+        assert decoder_out.size(2) == self.input_dim
 
-        encoder_out = encoder_out.unsqueeze(2)
-        # Now encoder_out is (N, T, 1, C)
+        N = encoder_out.size(0)
 
-        decoder_out = decoder_out.unsqueeze(1)
-        # Now decoder_out is (N, 1, U, C)
+        encoder_out_list = [
+            encoder_out[i, : encoder_out_len[i], :] for i in range(N)
+        ]
 
-        logit = encoder_out + decoder_out
-        logit = torch.tanh(logit)
+        decoder_out_list = [
+            decoder_out[i, : decoder_out_len[i], :] for i in range(N)
+        ]
 
-        output = self.output_linear(logit)
+        x = [
+            e.unsqueeze(1) + d.unsqueeze(0)
+            for e, d in zip(encoder_out_list, decoder_out_list)
+        ]
 
-        return output
+        x = [p.reshape(-1, self.input_dim) for p in x]
+        x = torch.cat(x)
+
+        activations = torch.tanh(x)
+
+        logits = self.output_linear(activations)
+
+        return logits
diff --git a/egs/librispeech/ASR/transducer_stateless/model.py b/egs/librispeech/ASR/transducer_stateless/model.py
index 2f0f9a183..8cd406df0 100644
--- a/egs/librispeech/ASR/transducer_stateless/model.py
+++ b/egs/librispeech/ASR/transducer_stateless/model.py
@@ -14,20 +14,71 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-"""
-Note we use `rnnt_loss` from torchaudio, which exists only in
-torchaudio >= v0.10.0. It also means you have to use torch >= v1.10.0
-"""
+import math
+
 import k2
 import torch
 import torch.nn as nn
-import torchaudio
-import torchaudio.functional
 from encoder_interface import EncoderInterface
 
 from icefall.utils import add_sos
 
 
+def reverse_label_smoothing(
+    logprobs: torch.Tensor, alpha: float
+) -> torch.Tensor:
+    """
+    This function was written by Dan.
+
+    Modifies `logprobs` in such a way that if you compute a data probability
+    using `logprobs`, it will be equivalent to a label-smoothed data probability
+    with the supplied label-smoothing constant alpha (e.g. alpha=0.1).
+    This allows us to use `logprobs` in things like RNN-T and CTC and
+    get a kind of label-smoothed version of those sequence objectives.
+
+    Label smoothing means that if the reference label is i, we convert it
+    into a distribution with weight (1-alpha) on i, and alpha distributed
+    equally to all labels (including i itself).
+
+    Note: the output logprobs can be interpreted as cross-entropies, meaning
+    we correct for the entropy of the smoothed distribution.
+
+    Args:
+      logprobs:
+        A Tensor of shape (*, num_classes), containing logprobs that sum
+        to one after exp, e.g. the output of log_softmax.
+      alpha:
+        A constant that defines the extent of label smoothing, e.g. 0.1.
+
+    Returns:
+      modified_logprobs, a Tensor of shape (*, num_classes), containing
+      "fake" logprobs that will give you label-smoothed probabilities.
+    """
+    assert alpha >= 0.0 and alpha < 1
+    if alpha == 0.0:
+        return logprobs
+    num_classes = logprobs.shape[-1]
+
+    # We correct for the entropy of the label-smoothed target distribution, so
+    # the resulting logprobs can be thought of as cross-entropies, which are
+    # more interpretable.
+    #
+    # The expression for entropy below is not quite correct -- it treats
+    # the target label and the smoothed version of the target label as being
+    # separate classes -- but this can be thought of as an adjustment
+    # for the way we compute the likelihood below, which also treats the
+    # target label and its smoothed version as being separate.
+    target_entropy = -(
+        (1 - alpha) * math.log(1 - alpha)
+        + alpha * math.log(alpha / num_classes)
+    )
+    sum_logprob = logprobs.sum(dim=-1, keepdim=True)
+
+    return (
+        logprobs * (1 - alpha) + sum_logprob * (alpha / num_classes)
+    ) + target_entropy
+
+
 class Transducer(nn.Module):
     """It implements https://arxiv.org/pdf/1211.3711.pdf
     "Sequence Transduction with Recurrent Neural Networks"
@@ -68,6 +119,7 @@ class Transducer(nn.Module):
         x: torch.Tensor,
         x_lens: torch.Tensor,
         y: k2.RaggedTensor,
+        label_smoothing_factor: float,
     ) -> torch.Tensor:
         """
         Args:
@@ -79,6 +131,8 @@ class Transducer(nn.Module):
           y:
             A ragged tensor with 2 axes [utt][label]. It contains labels of
             each utterance.
+          label_smoothing_factor:
+            The factor for label smoothing. Should be in the range [0, 1).
         Returns:
           Return the transducer loss.
         """
@@ -102,24 +156,35 @@ class Transducer(nn.Module):
 
         decoder_out = self.decoder(sos_y_padded)
 
-        logits = self.joiner(encoder_out, decoder_out)
+        # +1 here since a blank is prepended to each utterance.
+        logits = self.joiner(
+            encoder_out=encoder_out,
+            decoder_out=decoder_out,
+            encoder_out_len=x_lens,
+            decoder_out_len=y_lens + 1,
+        )
+        # logits is of shape (sum_all_TU, vocab_size)
+
+        log_probs = logits.log_softmax(dim=-1)
+        log_probs = reverse_label_smoothing(log_probs, label_smoothing_factor)
 
         # rnnt_loss requires 0 padded targets
         # Note: y does not start with SOS
         y_padded = y.pad(mode="constant", padding_value=0)
 
-        assert hasattr(torchaudio.functional, "rnnt_loss"), (
-            f"Current torchaudio version: {torchaudio.__version__}\n"
-            "Please install a version >= 0.10.0"
-        )
+        # We don't put this `import` at the beginning of the file
+        # as it is required only during training, not during the
+        # inference stage.
+        import optimized_transducer
 
-        loss = torchaudio.functional.rnnt_loss(
-            logits=logits,
+        loss = optimized_transducer.transducer_loss(
+            logits=log_probs,
             targets=y_padded,
             logit_lengths=x_lens,
             target_lengths=y_lens,
             blank=blank_id,
             reduction="sum",
+            from_log_softmax=True,
         )
 
         return loss
diff --git a/egs/librispeech/ASR/transducer_stateless/train.py b/egs/librispeech/ASR/transducer_stateless/train.py
index 694ebf1d5..41f8311ec 100755
--- a/egs/librispeech/ASR/transducer_stateless/train.py
+++ b/egs/librispeech/ASR/transducer_stateless/train.py
@@ -138,6 +138,13 @@ def get_parser():
         "2 means tri-gram",
     )
 
+    parser.add_argument(
+        "--label-smoothing-factor",
+        type=float,
+        default=0.1,
+        help="The factor for label smoothing",
+    )
+
     return parser
 
 
@@ -383,7 +390,12 @@ def compute_loss(
     y = k2.RaggedTensor(y).to(device)
 
     with torch.set_grad_enabled(is_training):
-        loss = model(x=feature, x_lens=feature_lens, y=y)
+        loss = model(
+            x=feature,
+            x_lens=feature_lens,
+            y=y,
+            label_smoothing_factor=params.label_smoothing_factor,
+        )
 
     assert loss.requires_grad == is_training
 
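Two notes with runnable sketches follow; neither is part of the patch itself. First, the new joiner output layout: because padding positions are sliced away per utterance before the encoder and decoder outputs are combined, the result has one row per (t, u) pair, concatenated over the batch. The sizes below are invented for illustration, and the import assumes you run from the recipe directory so the patched joiner.py is on the path:

```python
import torch

from joiner import Joiner  # the patched class above

joiner = Joiner(input_dim=8, output_dim=10)

encoder_out = torch.randn(2, 4, 8)  # (N, T, input_dim)
decoder_out = torch.randn(2, 3, 8)  # (N, U, input_dim)
encoder_out_len = torch.tensor([4, 2])  # valid frames per utterance
decoder_out_len = torch.tensor([3, 2])  # valid tokens, incl. the prepended blank

logits = joiner(encoder_out, decoder_out, encoder_out_len, decoder_out_len)

# sum_all_TU = 4 * 3 + 2 * 2 = 16 rows, versus N * T * U = 24
# positions in the old padded (N, T, U, C) layout.
expected_rows = int((encoder_out_len * decoder_out_len).sum())
assert logits.shape == (expected_rows, 10)
```

This unpadded layout is what `optimized_transducer.transducer_loss` consumes, and it never materializes logits for padding positions.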
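Second, a small numerical check of `reverse_label_smoothing`: the negated target entry of the modified logprobs should equal the label-smoothed cross-entropy, computed directly from the definition in the docstring, minus the target-entropy correction, and alpha=0.0 should be a no-op. Again a sketch; the values are arbitrary and the import assumes the recipe directory:

```python
import math

import torch

from model import reverse_label_smoothing  # the function added above

alpha = 0.1
num_classes = 5
logprobs = torch.randn(num_classes).log_softmax(dim=-1)

# alpha == 0.0 must return the input unchanged.
assert torch.equal(reverse_label_smoothing(logprobs, 0.0), logprobs)

modified = reverse_label_smoothing(logprobs, alpha)

target = 2
# Label-smoothed cross-entropy from its definition: weight (1 - alpha)
# on the target, alpha spread uniformly over all classes.
smoothed_ce = -(
    (1 - alpha) * logprobs[target] + (alpha / num_classes) * logprobs.sum()
)
target_entropy = -(
    (1 - alpha) * math.log(1 - alpha) + alpha * math.log(alpha / num_classes)
)
assert torch.isclose(-modified[target], smoothed_ce - target_entropy)
```

Note that the modified values are no longer normalized log-probabilities, consistent with model.py passing `from_log_softmax=True` so the loss does not apply log_softmax again.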