From e4d45adf5a24c2f6b2a9735c0ad3bdac7b402bb8 Mon Sep 17 00:00:00 2001
From: Fangjun Kuang
Date: Thu, 21 Apr 2022 11:01:08 +0800
Subject: [PATCH] Change model.py and joiner.py to use torchaudio's RNN-T loss.

---
 egs/librispeech/ASR/README.md                 |   1 +
 .../ASR/transducer_stateless3/joiner.py       |  38 +++++--
 .../ASR/transducer_stateless3/model.py        | 102 ++++--------------
 3 files changed, 51 insertions(+), 90 deletions(-)

diff --git a/egs/librispeech/ASR/README.md b/egs/librispeech/ASR/README.md
index de9d6d50a..24b29b653 100644
--- a/egs/librispeech/ASR/README.md
+++ b/egs/librispeech/ASR/README.md
@@ -14,6 +14,7 @@ The following table lists the differences among them.
 | `transducer` | Conformer | LSTM | |
 | `transducer_stateless` | Conformer | Embedding + Conv1d | Using optimized_transducer for computing RNN-T loss |
 | `transducer_stateless2` | Conformer | Embedding + Conv1d | Using torchaudio for computing RNN-T loss |
+| `transducer_stateless3` | Conformer(modified) | Embedding + Conv1d | Using torchaudio for computing RNN-T loss |
 | `transducer_lstm` | LSTM | LSTM | |
 | `transducer_stateless_multi_datasets` | Conformer | Embedding + Conv1d | Using data from GigaSpeech as extra training data |
 | `pruned_transducer_stateless` | Conformer | Embedding + Conv1d | Using k2 pruned RNN-T loss |
diff --git a/egs/librispeech/ASR/transducer_stateless3/joiner.py b/egs/librispeech/ASR/transducer_stateless3/joiner.py
index 35f75ed2a..18a6cd845 100644
--- a/egs/librispeech/ASR/transducer_stateless3/joiner.py
+++ b/egs/librispeech/ASR/transducer_stateless3/joiner.py
@@ -33,6 +33,10 @@ class Joiner(nn.Module):
         self.decoder_proj = ScaledLinear(decoder_dim, joiner_dim)
         self.output_linear = ScaledLinear(joiner_dim, vocab_size)
 
+        self.encoder_dim = encoder_dim
+        self.decoder_dim = decoder_dim
+        self.joiner_dim = joiner_dim
+
     def forward(
         self,
         encoder_out: torch.Tensor,
@@ -42,9 +46,9 @@
         """
         Args:
           encoder_out:
-            Output from the encoder. Its shape is (N, T, s_range, C).
+            Output from the encoder. Its shape is (N, T, joiner_dim).
           decoder_out:
-            Output from the decoder. Its shape is (N, T, s_range, C).
+            Output from the decoder. Its shape is (N, U, joiner_dim).
           project_input:
             If true, apply input projections encoder_proj and decoder_proj.
             If this is false, it is the user's responsibility to do this
             manually.
         Returns:
-          Return a tensor of shape (N, T, s_range, C).
+          Return a tensor of shape (N, T, U, vocab_size).
         """
-        assert encoder_out.ndim == decoder_out.ndim == 4
-        assert encoder_out.shape[:-1] == decoder_out.shape[:-1]
+        assert encoder_out.ndim == decoder_out.ndim == 3
+        assert encoder_out.size(0) == decoder_out.size(0)
 
         if project_input:
-            logit = self.encoder_proj(encoder_out) + self.decoder_proj(
-                decoder_out
-            )
+            assert encoder_out.size(2) == self.encoder_dim
+            assert decoder_out.size(2) == self.decoder_dim
+            encoder_out = self.encoder_proj(encoder_out)
+            decoder_out = self.decoder_proj(decoder_out)
         else:
-            logit = encoder_out + decoder_out
+            assert encoder_out.size(2) == self.joiner_dim
+            assert decoder_out.size(2) == self.joiner_dim
 
-        logit = self.output_linear(torch.tanh(logit))
+        encoder_out = encoder_out.unsqueeze(2)  # (N, T, 1, C)
+        decoder_out = decoder_out.unsqueeze(1)  # (N, 1, U, C)
+        x = encoder_out + decoder_out  # (N, T, U, C)
 
-        return logit
+        activations = torch.tanh(x)
+
+        logits = self.output_linear(activations)
+
+        if not self.training:
+            # We reuse the beam_search.py from transducer_stateless,
+            # which expects that the joiner network outputs
+            # a 2-D tensor.
+            logits = logits.squeeze(2).squeeze(1)
+
+        return logits
diff --git a/egs/librispeech/ASR/transducer_stateless3/model.py b/egs/librispeech/ASR/transducer_stateless3/model.py
index 599bf2506..7c4881c72 100644
--- a/egs/librispeech/ASR/transducer_stateless3/model.py
+++ b/egs/librispeech/ASR/transducer_stateless3/model.py
@@ -63,19 +63,11 @@ class Transducer(nn.Module):
         self.decoder = decoder
         self.joiner = joiner
 
-        self.simple_am_proj = ScaledLinear(
-            encoder_dim, vocab_size, initial_speed=0.5
-        )
-        self.simple_lm_proj = ScaledLinear(decoder_dim, vocab_size)
-
     def forward(
         self,
         x: torch.Tensor,
         x_lens: torch.Tensor,
         y: k2.RaggedTensor,
-        prune_range: int = 5,
-        am_scale: float = 0.0,
-        lm_scale: float = 0.0,
         warmup: float = 1.0,
     ) -> torch.Tensor:
         """
@@ -88,26 +80,11 @@
           y:
             A ragged tensor with 2 axes [utt][label]. It contains labels of each
             utterance.
-          prune_range:
-            The prune range for rnnt loss, it means how many symbols(context)
-            we are considering for each frame to compute the loss.
-          am_scale:
-            The scale to smooth the loss with am (output of encoder network)
-            part
-          lm_scale:
-            The scale to smooth the loss with lm (output of predictor network)
-            part
           warmup:
             A value warmup >= 0 that determines which modules are active, values
             warmup > 1 "are fully warmed up" and all modules will be active.
         Returns:
           Return the transducer loss.
-
-        Note:
-           Regarding am_scale & lm_scale, it will make the loss-function one of
-           the form:
-              lm_scale * lm_probs + am_scale * am_probs +
-              (1-lm_scale-am_scale) * combined_probs
         """
         assert x.ndim == 3, x.shape
         assert x_lens.ndim == 1, x_lens.shape
@@ -115,8 +92,8 @@
 
         assert x.size(0) == x_lens.size(0) == y.dim0
 
-        encoder_out, x_lens = self.encoder(x, x_lens, warmup=warmup)
-        assert torch.all(x_lens > 0)
+        encoder_out, encoder_out_lens = self.encoder(x, x_lens, warmup=warmup)
+        assert torch.all(encoder_out_lens > 0)
 
         # Now for the decoder, i.e., the prediction network
         row_splits = y.shape.row_splits(1)
@@ -125,69 +102,34 @@
         blank_id = self.decoder.blank_id
         sos_y = add_sos(y, sos_id=blank_id)
 
-        # sos_y_padded: [B, S + 1], start with SOS.
+        # sos_y_padded: [B, U + 1], start with SOS.
         sos_y_padded = sos_y.pad(mode="constant", padding_value=blank_id)
 
-        # decoder_out: [B, S + 1, decoder_dim]
+        # decoder_out: [B, U + 1, decoder_dim]
         decoder_out = self.decoder(sos_y_padded)
 
+        logits = self.joiner(
+            encoder_out=encoder_out,
+            decoder_out=decoder_out,
+            project_input=True,
+        )
+
         # Note: y does not start with SOS
-        # y_padded : [B, S]
+        # y_padded : [B, U]
         y_padded = y.pad(mode="constant", padding_value=0)
 
-        y_padded = y_padded.to(torch.int64)
-        boundary = torch.zeros(
-            (x.size(0), 4), dtype=torch.int64, device=x.device
-        )
-        boundary[:, 2] = y_lens
-        boundary[:, 3] = x_lens
-
-        lm = self.simple_lm_proj(decoder_out)
-        am = self.simple_am_proj(encoder_out)
-
-        with torch.cuda.amp.autocast(enabled=False):
-            simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed(
-                lm=lm.float(),
-                am=am.float(),
-                symbols=y_padded,
-                termination_symbol=blank_id,
-                lm_only_scale=lm_scale,
-                am_only_scale=am_scale,
-                boundary=boundary,
-                reduction="sum",
-                return_grad=True,
-            )
-
-        # ranges : [B, T, prune_range]
-        ranges = k2.get_rnnt_prune_ranges(
-            px_grad=px_grad,
-            py_grad=py_grad,
-            boundary=boundary,
-            s_range=prune_range,
+        assert hasattr(torchaudio.functional, "rnnt_loss"), (
+            f"Current torchaudio version: {torchaudio.__version__}\n"
+            "Please install a version >= 0.10.0"
         )
 
-        # am_pruned : [B, T, prune_range, encoder_dim]
-        # lm_pruned : [B, T, prune_range, decoder_dim]
-        am_pruned, lm_pruned = k2.do_rnnt_pruning(
-            am=self.joiner.encoder_proj(encoder_out),
-            lm=self.joiner.decoder_proj(decoder_out),
-            ranges=ranges,
+        loss = torchaudio.functional.rnnt_loss(
+            logits=logits,
+            targets=y_padded,
+            logit_lengths=encoder_out_lens,
+            target_lengths=y_lens,
+            blank=blank_id,
+            reduction="sum",
         )
 
-        # logits : [B, T, prune_range, vocab_size]
-
-        # project_input=False since we applied the decoder's input projections
-        # prior to do_rnnt_pruning (this is an optimization for speed).
-        logits = self.joiner(am_pruned, lm_pruned, project_input=False)
-
-        with torch.cuda.amp.autocast(enabled=False):
-            pruned_loss = k2.rnnt_loss_pruned(
-                logits=logits.float(),
-                symbols=y_padded,
-                ranges=ranges,
-                termination_symbol=blank_id,
-                boundary=boundary,
-                reduction="sum",
-            )
-
-        return (simple_loss, pruned_loss)
+        return loss
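
For reference, the following is a minimal, self-contained sketch of the tensor flow this patch sets up: the joiner broadcasts the projected encoder output (N, T, C) against the projected decoder output (N, U + 1, C) to produce logits of shape (N, T, U + 1, vocab_size), which is the layout that torchaudio.functional.rnnt_loss consumes. The dimensions and variable names below are illustrative, and output_linear is a plain nn.Linear standing in for the recipe's ScaledLinear; this is not code taken from the recipe itself.

import torch
import torchaudio
import torchaudio.functional

N, T, U, joiner_dim, vocab_size = 2, 50, 10, 512, 500
blank_id = 0

# Inputs as Joiner.forward sees them after encoder_proj/decoder_proj
# (the project_input=False case). The decoder output has U + 1 positions
# because its input is prefixed with SOS (the blank symbol).
encoder_out = torch.randn(N, T, joiner_dim)
decoder_out = torch.randn(N, U + 1, joiner_dim)

# Broadcast-add over a new axis, as in the rewritten Joiner.forward.
x = encoder_out.unsqueeze(2) + decoder_out.unsqueeze(1)  # (N, T, U + 1, joiner_dim)
activations = torch.tanh(x)
output_linear = torch.nn.Linear(joiner_dim, vocab_size)  # stand-in for ScaledLinear
logits = output_linear(activations)  # (N, T, U + 1, vocab_size)

# torchaudio.functional.rnnt_loss is available since torchaudio 0.10.
# Targets and lengths are passed as int32, which its RNN-T kernels accept.
targets = torch.randint(1, vocab_size, (N, U), dtype=torch.int32)
logit_lengths = torch.full((N,), T, dtype=torch.int32)
target_lengths = torch.full((N,), U, dtype=torch.int32)

loss = torchaudio.functional.rnnt_loss(
    logits=logits,
    targets=targets,
    logit_lengths=logit_lengths,
    target_lengths=target_lengths,
    blank=blank_id,
    reduction="sum",
)
print(loss.item())

Unlike the k2 pruned loss that this patch removes, the full (N, T, U + 1, vocab_size) logits tensor is materialized here, so memory grows with T * U rather than T * prune_range; that is the trade-off of switching to torchaudio's unpruned RNN-T loss.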