From 3e3c1a6aee3814ff8c96633ad01fede8b41330cd Mon Sep 17 00:00:00 2001
From: pkufool
Date: Fri, 28 Jan 2022 20:05:48 +0800
Subject: [PATCH] update docs

---
 .../ASR/transducer_stateless/decoder.py | 31 ++++++++++++-------
 .../ASR/transducer_stateless/model.py   | 29 ++++++++++++++---
 2 files changed, 44 insertions(+), 16 deletions(-)

diff --git a/egs/librispeech/ASR/transducer_stateless/decoder.py b/egs/librispeech/ASR/transducer_stateless/decoder.py
index aec745a9c..5b54cc0a8 100644
--- a/egs/librispeech/ASR/transducer_stateless/decoder.py
+++ b/egs/librispeech/ASR/transducer_stateless/decoder.py
@@ -38,7 +38,7 @@ class Decoder(nn.Module):
         embedding_dim: int,
         blank_id: int,
         context_size: int,
-        backward: bool = False,
+        use_right_context: bool = False,
     ):
         """
         Args:
@@ -51,6 +51,9 @@
           context_size:
             Number of previous words to use to predict the next word.
             1 means bigram; 2 means trigram. n means (n+1)-gram.
+          use_right_context:
+            True to use right context, which is useful for implementing a
+            backward decoder; only used for training.
         """
         super().__init__()
         self.embedding = nn.Embedding(
@@ -62,7 +65,7 @@
 
         assert context_size >= 1, context_size
         self.context_size = context_size
-        self.backward = backward
+        self.use_right_context = use_right_context
         if context_size > 1:
             self.conv = nn.Conv1d(
                 in_channels=embedding_dim,
@@ -88,14 +91,20 @@
         if self.context_size > 1:
             embedding_out = embedding_out.permute(0, 2, 1)
             if need_pad is True:
-                # If the input is [sos, a, b, c, d] and output is
-                # [a, b, c, d, eos], padding left and using kernel-size 2,
-                # it uses left context.
-                # If the input is [a, b, c, d, eos] and output is
-                # [sos, a, b, c, d], padding right and using kernel-size 2,
-                # it uses right context.
-                if self.backward:
-                    assert self.context_size == 2
+                # Regarding the left or right context we are using:
+                # If we feed the sequence [sos, a, b, c, d] to this decoder and
+                # want to predict the sequence [a, b, c, d], then after padding to
+                # the left with context_size == 2, the fed-in sequence becomes
+                # [pad, sos, a, b, c, d], and we use `pad,sos` to predict `a`,
+                # `sos,a` to predict `b`, ...; that is left context.
+                # If we feed the sequence [b, c, d, blk, blk] to this decoder
+                # and want to predict the sequence [a, b, c, d], then after padding
+                # to the right with context_size == 2, the fed-in sequence becomes
+                # [b, c, d, blk, blk, pad], and we use `b,c` to predict `a`,
+                # `c,d` to predict `b`, ...; that is right context.
+                # This is tricky and not so straightforward; we will find a better
+                # implementation later.
+                if self.use_right_context:
                     embedding_out = F.pad(
                         embedding_out, pad=(0, self.context_size - 1)
                     )
@@ -107,7 +116,7 @@
             # During inference time, there is no need to do extra padding
             # as we only need one output
             assert embedding_out.size(-1) == self.context_size
-            assert self.backward is False
+            assert self.use_right_context is False
             embedding_out = self.conv(embedding_out)
             embedding_out = embedding_out.permute(0, 2, 1)
         return embedding_out
diff --git a/egs/librispeech/ASR/transducer_stateless/model.py b/egs/librispeech/ASR/transducer_stateless/model.py
index 1338c4df3..5e415e4b6 100644
--- a/egs/librispeech/ASR/transducer_stateless/model.py
+++ b/egs/librispeech/ASR/transducer_stateless/model.py
@@ -21,7 +21,7 @@
 import torch.nn as nn
 import torch.nn.functional as F
 from encoder_interface import EncoderInterface
-from icefall.utils import add_eos, add_sos
+from icefall.utils import add_sos
 
 
 class Transducer(nn.Module):
@@ -124,11 +124,14 @@
         blank_id = self.decoder.blank_id
         sos_y = add_sos(y, sos_id=blank_id)
 
+        # sos_y_padded: [B, S + 1], starts with SOS.
         sos_y_padded = sos_y.pad(mode="constant", padding_value=blank_id)
 
+        # decoder_out: [B, S + 1, C]
         decoder_out = self.decoder(sos_y_padded)
 
         # Note: y does not start with SOS
+        # y_padded : [B, S]
         y_padded = y.pad(mode="constant", padding_value=0)
 
         boundary = torch.zeros(
@@ -148,33 +151,49 @@
             boundary=boundary,
             return_grad=True,
         )
+
+        # ranges : [B, T, prune_range]
         ranges = k2.get_rnnt_prune_ranges(
             px_grad, py_grad, boundary, self.prune_range
         )
 
         # forward loss
+        # am_pruned : [B, T, prune_range, C]
+        # lm_pruned : [B, T, prune_range, C]
         am_pruned, lm_pruned = k2.do_rnnt_pruning(
             encoder_out, decoder_out, ranges
         )
+        # logits : [B, T, prune_range, C]
         logits = self.joiner(am_pruned, lm_pruned)
+
         pruned_loss = k2.rnnt_loss_pruned(
             logits, y_padded.to(torch.int64), ranges, blank_id, boundary
         )
 
-        eos_y = add_eos(y, eos_id=blank_id)
-        eos_y_padded = eos_y.pad(mode="constant", padding_value=blank_id)
-        eos_y_padded = F.pad(eos_y_padded[:, 1:], pad=(0, 1), value=blank_id)
+        # y_padded shape : [B, S]
+        # We skip the first symbol (a shift trick for right context), so we
+        # have to pad 2 blanks to the right to make the output shape of the
+        # decoder [B, S + 1, C].
+        # backward_y shape : [B, S + 1]
+        backward_y = F.pad(y_padded[:, 1:], pad=(0, 2), value=blank_id)
 
         # backward loss
         assert self.backward_decoder is not None
         assert self.backward_joiner is not None
-        backward_decoder_out = self.backward_decoder(eos_y_padded)
+        # backward_decoder_out : [B, S + 1, C]
+        backward_decoder_out = self.backward_decoder(backward_y)
+
+        # backward_am_pruned : [B, T, prune_range, C]
+        # backward_lm_pruned : [B, T, prune_range, C]
         backward_am_pruned, backward_lm_pruned = k2.do_rnnt_pruning(
             encoder_out, backward_decoder_out, ranges
         )
+
+        # backward_logits : [B, T, prune_range, C]
         backward_logits = self.backward_joiner(
             backward_am_pruned, backward_lm_pruned
         )
+
         backward_pruned_loss = k2.rnnt_loss_pruned(
             backward_logits,
             y_padded.to(torch.int64),
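Note for reviewers: below is a minimal, self-contained sketch (not part of the patch) of the left- vs right-context padding trick described in the new decoder.py comments. The toy tensor, single channel, and print statements are illustrative assumptions; only context_size == 2 and the use of F.pad plus a kernel-size-2 Conv1d come from the patch.

# Sketch: left vs right context via padding side, with a kernel-size-2 Conv1d.
import torch
import torch.nn.functional as F

context_size = 2  # as assumed in the decoder.py comments

# Toy "embedding output": batch=1, channels=1, 5 time steps, one per symbol.
x = torch.arange(1.0, 6.0).reshape(1, 1, 5)

# Left context: pad on the left, so output step t sees symbols t-1 and t
# (e.g. `pad,sos` predicts `a`, `sos,a` predicts `b`, ...).
left = F.pad(x, pad=(context_size - 1, 0))

# Right context: pad on the right, so output step t sees symbols t and t+1
# (e.g. `b,c` predicts `a`, `c,d` predicts `b`, ...).
right = F.pad(x, pad=(0, context_size - 1))

conv = torch.nn.Conv1d(in_channels=1, out_channels=1, kernel_size=context_size)
print(conv(left).shape)   # torch.Size([1, 1, 5])
print(conv(right).shape)  # torch.Size([1, 1, 5])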