From a432e356a5ef23b613a859da75d7b4f9f872ab0e Mon Sep 17 00:00:00 2001
From: pkufool
Date: Thu, 17 Feb 2022 12:47:17 +0800
Subject: [PATCH] Minor fixes

---
 egs/librispeech/ASR/RESULTS.md                | 48 +++++++++++++++
 .../ASR/pruned_transducer_stateless/decode.py | 12 ++--
 .../ASR/pruned_transducer_stateless/export.py |  4 +-
 .../ASR/pruned_transducer_stateless/joiner.py |  7 +--
 .../ASR/pruned_transducer_stateless/model.py  | 58 ++++++++++---------
 .../ASR/pruned_transducer_stateless/train.py  | 28 ++++++---
 6 files changed, 112 insertions(+), 45 deletions(-)

diff --git a/egs/librispeech/ASR/RESULTS.md b/egs/librispeech/ASR/RESULTS.md
index ffeaaae68..a51a0208f 100644
--- a/egs/librispeech/ASR/RESULTS.md
+++ b/egs/librispeech/ASR/RESULTS.md
@@ -1,5 +1,53 @@
 ## Results
 
+### LibriSpeech BPE training results (Pruned Transducer)
+
+#### Conformer encoder + embedding decoder
+
+Conformer encoder + non-recurrent decoder. The decoder
+contains only an embedding layer, a Conv1d (with kernel size 2) and a linear
+layer (to transform the tensor dim).
+
+The WERs are
+
+|               | test-clean | test-other | comment                                  |
+|---------------|------------|------------|------------------------------------------|
+| greedy search | 2.85       | 6.98       | --epoch 28, --avg 15, --max-duration 100 |
+
+The training command to reproduce the above results is:
+
+```
+export CUDA_VISIBLE_DEVICES="0,1,2,3"
+
+./pruned_transducer_stateless/train.py \
+  --world-size 4 \
+  --num-epochs 30 \
+  --start-epoch 0 \
+  --exp-dir pruned_transducer_stateless/exp \
+  --full-libri 1 \
+  --max-duration 300 \
+  --prune-range 5 \
+  --lr-factor 5 \
+  --lm-scale 0.25
+```
+
+The tensorboard training log can be found at
+
+The decoding command is:
+```
+epoch=28
+avg=15
+
+## greedy search
+./pruned_transducer_stateless/decode.py \
+  --epoch $epoch \
+  --avg $avg \
+  --exp-dir pruned_transducer_stateless/exp \
+  --max-duration 100
+```
+
+
 ### LibriSpeech BPE training results (Transducer)
 
 #### Conformer encoder + embedding decoder

diff --git a/egs/librispeech/ASR/pruned_transducer_stateless/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless/decode.py
index 82da8f076..9479d57a8 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless/decode.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless/decode.py
@@ -19,16 +19,16 @@
 Usage:
 (1) greedy search
 ./pruned_transducer_stateless/decode.py \
-    --epoch 14 \
-    --avg 7 \
+    --epoch 28 \
+    --avg 15 \
     --exp-dir ./pruned_transducer_stateless/exp \
     --max-duration 100 \
     --decoding-method greedy_search
 
 (2) beam search
 ./pruned_transducer_stateless/decode.py \
-    --epoch 14 \
-    --avg 7 \
+    --epoch 28 \
+    --avg 15 \
     --exp-dir ./pruned_transducer_stateless/exp \
     --max-duration 100 \
     --decoding-method beam_search \
@@ -70,14 +70,14 @@ def get_parser():
     parser.add_argument(
         "--epoch",
         type=int,
-        default=29,
+        default=28,
         help="It specifies the checkpoint to use for decoding."
         "Note: Epoch counts from 0.",
     )
     parser.add_argument(
         "--avg",
         type=int,
-        default=13,
+        default=15,
         help="Number of checkpoints to average. Automatically select "
         "consecutive checkpoints before the checkpoint specified by "
         "'--epoch'. ",

diff --git a/egs/librispeech/ASR/pruned_transducer_stateless/export.py b/egs/librispeech/ASR/pruned_transducer_stateless/export.py
index c653cf3fc..94987c39a 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless/export.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless/export.py
@@ -68,7 +68,7 @@ def get_parser():
     parser.add_argument(
         "--epoch",
         type=int,
-        default=20,
+        default=28,
         help="It specifies the checkpoint to use for decoding."
         "Note: Epoch counts from 0.",
     )
@@ -76,7 +76,7 @@ def get_parser():
     parser.add_argument(
         "--avg",
         type=int,
-        default=10,
+        default=15,
         help="Number of checkpoints to average. Automatically select "
         "consecutive checkpoints before the checkpoint specified by "
         "'--epoch'. ",

diff --git a/egs/librispeech/ASR/pruned_transducer_stateless/joiner.py b/egs/librispeech/ASR/pruned_transducer_stateless/joiner.py
index be3e373a0..7c5a93a86 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless/joiner.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless/joiner.py
@@ -16,6 +16,7 @@
 
 import torch
 import torch.nn as nn
+import torch.nn.functional as F
 
 
 class Joiner(nn.Module):
@@ -42,10 +43,8 @@ class Joiner(nn.Module):
 
         logit = encoder_out + decoder_out
 
-        logit = self.inner_linear(logit)
+        logit = self.inner_linear(torch.tanh(logit))
 
-        logit = torch.tanh(logit)
-
-        output = self.output_linear(logit)
+        output = self.output_linear(F.relu(logit))
 
         return output
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless/model.py b/egs/librispeech/ASR/pruned_transducer_stateless/model.py
index 4243ce418..2f019bcdb 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless/model.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless/model.py
@@ -33,9 +33,6 @@ class Transducer(nn.Module):
         encoder: EncoderInterface,
         decoder: nn.Module,
         joiner: nn.Module,
-        prune_range: int = 3,
-        am_scale: float = 0.0,
-        lm_scale: float = 0.0,
     ):
         """
         Args:
@@ -52,21 +49,6 @@ class Transducer(nn.Module):
             It has two inputs with shapes: (N, T, C) and (N, U, C). Its
             output shape is (N, T, U, C). Note that its output contains
             unnormalized probs, i.e., not processed by log-softmax.
-          prune_range:
-            The prune range for rnnt loss, it means how many symbols(context)
-            we are considering for each frame to compute the loss.
-          am_scale:
-            The scale to smooth the loss with am (output of encoder network)
-            part
-          lm_scale:
-            The scale to smooth the loss with lm (output of predictor network)
-            part
-
-        Note:
-          Regarding am_scale & lm_scale, it will make the loss-function one of
-          the form:
-            lm_scale * lm_probs + am_scale * am_probs +
-            (1-lm_scale-am_scale) * combined_probs
         """
         super().__init__()
         assert isinstance(encoder, EncoderInterface), type(encoder)
@@ -75,15 +57,15 @@ class Transducer(nn.Module):
         self.encoder = encoder
         self.decoder = decoder
         self.joiner = joiner
-        self.prune_range = prune_range
-        self.lm_scale = lm_scale
-        self.am_scale = am_scale
 
     def forward(
         self,
         x: torch.Tensor,
         x_lens: torch.Tensor,
         y: k2.RaggedTensor,
+        prune_range: int = 5,
+        am_scale: float = 0.0,
+        lm_scale: float = 0.0,
     ) -> torch.Tensor:
         """
         Args:
@@ -95,8 +77,23 @@ class Transducer(nn.Module):
           y:
             A ragged tensor with 2 axes [utt][label]. It contains labels of
             each utterance.
+          prune_range:
+            The prune range for the rnnt loss, i.e., how many symbols (of
+            context) we consider for each frame when computing the loss.
+          am_scale:
+            The scale used to smooth the loss with the am (output of the
+            encoder network) part.
+          lm_scale:
+            The scale used to smooth the loss with the lm (output of the
+            predictor network) part.
         Returns:
           Return the transducer loss.
+
+        Note:
+          Regarding am_scale & lm_scale, they make the loss function take
+          the form:
+            lm_scale * lm_probs + am_scale * am_probs +
+            (1 - lm_scale - am_scale) * combined_probs
         """
         assert x.ndim == 3, x.shape
         assert x_lens.ndim == 1, x_lens.shape
@@ -114,11 +111,14 @@ class Transducer(nn.Module):
         blank_id = self.decoder.blank_id
         sos_y = add_sos(y, sos_id=blank_id)
 
+        # sos_y_padded: [B, S + 1], starts with SOS.
         sos_y_padded = sos_y.pad(mode="constant", padding_value=blank_id)
 
+        # decoder_out: [B, S + 1, C]
         decoder_out = self.decoder(sos_y_padded)
 
         # Note: y does not start with SOS
+        # y_padded: [B, S]
         y_padded = y.pad(mode="constant", padding_value=0)
 
         y_padded = y_padded.to(torch.int64)
@@ -133,31 +133,37 @@ class Transducer(nn.Module):
             am=encoder_out,
             symbols=y_padded,
             termination_symbol=blank_id,
-            lm_only_scale=self.lm_scale,
-            am_only_scale=self.am_scale,
+            lm_only_scale=lm_scale,
+            am_only_scale=am_scale,
             boundary=boundary,
+            reduction="sum",
             return_grad=True,
         )
 
+        # ranges: [B, T, prune_range]
         ranges = k2.get_rnnt_prune_ranges(
             px_grad=px_grad,
             py_grad=py_grad,
             boundary=boundary,
-            s_range=self.prune_range,
+            s_range=prune_range,
         )
 
+        # am_pruned: [B, T, prune_range, C]
+        # lm_pruned: [B, T, prune_range, C]
         am_pruned, lm_pruned = k2.do_rnnt_pruning(
             am=encoder_out, lm=decoder_out, ranges=ranges
         )
 
+        # logits: [B, T, prune_range, C]
         logits = self.joiner(am_pruned, lm_pruned)
 
         pruned_loss = k2.rnnt_loss_pruned(
-            joint=logits,
+            logits=logits,
             symbols=y_padded,
             ranges=ranges,
             termination_symbol=blank_id,
             boundary=boundary,
+            reduction="sum",
         )
 
-        return (-torch.sum(simple_loss), -torch.sum(pruned_loss))
+        return (simple_loss, pruned_loss)
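Reading the model.py hunks as one piece: `forward` now takes `prune_range`/`am_scale`/`lm_scale` per call, computes a smoothed "simple" loss whose gradients locate the useful symbol positions, and then evaluates the joiner and the pruned loss only on a window of `prune_range` symbols per frame; both terms come back already summed (`reduction="sum"`), so the old `-torch.sum(...)` wrapping is gone. Below is a condensed sketch of the two-stage pipeline using the same k2 calls as the diff; the `boundary` construction (each row laid out as `[0, 0, S, T]`) is an assumption, since it sits outside the hunks shown:

```
import k2
import torch


def pruned_rnnt_loss(encoder_out, decoder_out, y_padded, x_lens, y_lens,
                     joiner, blank_id, prune_range=5, am_scale=0.0,
                     lm_scale=0.0):
    # Assumed boundary layout per row: [0, 0, S, T].
    boundary = torch.zeros(
        (encoder_out.size(0), 4), dtype=torch.int64, device=encoder_out.device
    )
    boundary[:, 2] = y_lens
    boundary[:, 3] = x_lens

    # Stage 1: smoothed "simple" loss; its px/py gradients rank symbol
    # positions by usefulness for each frame.
    simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed(
        lm=decoder_out, am=encoder_out, symbols=y_padded,
        termination_symbol=blank_id, lm_only_scale=lm_scale,
        am_only_scale=am_scale, boundary=boundary, reduction="sum",
        return_grad=True,
    )

    # Stage 2: keep only prune_range symbols per frame, then run the
    # (expensive) joiner and the pruned loss on that window alone.
    ranges = k2.get_rnnt_prune_ranges(
        px_grad=px_grad, py_grad=py_grad, boundary=boundary,
        s_range=prune_range,
    )
    am_pruned, lm_pruned = k2.do_rnnt_pruning(
        am=encoder_out, lm=decoder_out, ranges=ranges
    )
    logits = joiner(am_pruned, lm_pruned)  # (B, T, prune_range, C)
    pruned_loss = k2.rnnt_loss_pruned(
        logits=logits, symbols=y_padded, ranges=ranges,
        termination_symbol=blank_id, boundary=boundary, reduction="sum",
    )
    return simple_loss, pruned_loss
```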
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless/train.py b/egs/librispeech/ASR/pruned_transducer_stateless/train.py
index abd91d33f..e19473788 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless/train.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless/train.py
@@ -148,7 +148,7 @@ def get_parser():
     parser.add_argument(
         "--prune-range",
         type=int,
-        default=3,
+        default=5,
         help="The prune range for rnnt loss, it means how many symbols(context)"
         "we are using to compute the loss",
     )
@@ -156,7 +156,7 @@ def get_parser():
     parser.add_argument(
         "--lm-scale",
         type=float,
-        default=0.5,
+        default=0.25,
         help="The scale to smooth the loss with lm "
         "(output of prediction network) part.",
     )
@@ -169,6 +169,16 @@ def get_parser():
         "part.",
     )
 
+    parser.add_argument(
+        "--simple-loss-scale",
+        type=float,
+        default=0.5,
+        help="To get pruning ranges, we calculate a simple version of the "
+        "loss (where the joiner is just an addition); this simple loss is "
+        "also used for training, as a regularization term. The simple loss "
+        "is scaled by this parameter before being added to the final loss.",
+    )
+
     return parser
 
@@ -289,9 +299,6 @@ def get_transducer_model(params: AttributeDict) -> nn.Module:
         encoder=encoder,
         decoder=decoder,
         joiner=joiner,
-        prune_range=params.prune_range,
-        lm_scale=params.lm_scale,
-        am_scale=params.am_scale,
     )
     return model
 
@@ -420,8 +427,15 @@ def compute_loss(
     y = k2.RaggedTensor(y).to(device)
 
     with torch.set_grad_enabled(is_training):
-        simple_loss, pruned_loss = model(x=feature, x_lens=feature_lens, y=y)
-        loss = simple_loss + pruned_loss
+        simple_loss, pruned_loss = model(
+            x=feature,
+            x_lens=feature_lens,
+            y=y,
+            prune_range=params.prune_range,
+            am_scale=params.am_scale,
+            lm_scale=params.lm_scale,
+        )
+        loss = params.simple_loss_scale * simple_loss + pruned_loss
 
     assert loss.requires_grad == is_training
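Since the scales now travel with each `forward` call rather than the constructor, any outside code that built a `Transducer(..., prune_range=..., am_scale=..., lm_scale=...)` must drop those keyword arguments and pass the values at the call site instead. A sketch of the updated call under the new defaults; `model`, `feature`, `feature_lens`, and `y` are assumed to exist as in `compute_loss` above:

```
# Hypothetical call site after this patch, shown with the new defaults.
simple_loss, pruned_loss = model(
    x=feature,
    x_lens=feature_lens,
    y=y,
    prune_range=5,   # moved out of the constructor; default raised from 3
    am_scale=0.0,
    lm_scale=0.25,   # default lowered from 0.5
)
# The simple loss only regularizes training and defines the prune ranges,
# so it enters the total down-weighted by --simple-loss-scale (default 0.5):
loss = 0.5 * simple_loss + pruned_loss
```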