From e46409e90f53702a562ee755c43c26ed808e1416 Mon Sep 17 00:00:00 2001
From: pkufool
Date: Thu, 20 Jan 2022 11:42:02 +0800
Subject: [PATCH] Update aishell with k2 pruned rnnt loss

---
 .../ASR/transducer_stateless/beam_search.py   |  4 +-
 .../ASR/transducer_stateless/decode.py        | 11 ++--
 .../ASR/transducer_stateless/decoder.py       |  4 +-
 .../ASR/transducer_stateless/joiner.py        | 17 ++----
 egs/aishell/ASR/transducer_stateless/model.py | 56 ++++++++++++-------
 egs/aishell/ASR/transducer_stateless/train.py | 49 ++++++++++++----
 6 files changed, 90 insertions(+), 51 deletions(-)

diff --git a/egs/aishell/ASR/transducer_stateless/beam_search.py b/egs/aishell/ASR/transducer_stateless/beam_search.py
index 9ed9b2ad1..f347f552f 100644
--- a/egs/aishell/ASR/transducer_stateless/beam_search.py
+++ b/egs/aishell/ASR/transducer_stateless/beam_search.py
@@ -73,9 +73,9 @@ def greedy_search(
             continue

         # fmt: off
-        current_encoder_out = encoder_out[:, t:t+1, :]
+        current_encoder_out = encoder_out[:, t:t+1, :].unsqueeze(2)
         # fmt: on
-        logits = model.joiner(current_encoder_out, decoder_out)
+        logits = model.joiner(current_encoder_out, decoder_out.unsqueeze(1))
         # logits is (1, 1, 1, vocab_size)

         y = logits.argmax().item()
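
The joiner's calling convention changes with this patch: both inputs must
already be broadcast to the same 4-D shape before the call. A minimal,
self-contained shape check for the greedy-search hunk above (the vocab size
and tensors are illustrative, not taken from the recipe):

    import torch

    C = 500                              # illustrative vocab size
    encoder_out = torch.randn(1, 25, C)  # (N, T, C) from the encoder
    decoder_out = torch.randn(1, 1, C)   # (N, 1, C) from the stateless decoder

    t = 0
    # One encoder frame, expanded to (1, 1, 1, C), matches the expanded
    # decoder output; the joiner then adds them elementwise.
    current_encoder_out = encoder_out[:, t:t + 1, :].unsqueeze(2)
    assert current_encoder_out.shape == (1, 1, 1, C)
    assert decoder_out.unsqueeze(1).shape == (1, 1, 1, C)
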
""" embeding_out = self.embedding(y) if self.context_size > 1: @@ -95,4 +96,5 @@ class Decoder(nn.Module): assert embeding_out.size(-1) == self.context_size embeding_out = self.conv(embeding_out) embeding_out = embeding_out.permute(0, 2, 1) + embeding_out = self.output_linear(embeding_out) return embeding_out diff --git a/egs/aishell/ASR/transducer_stateless/joiner.py b/egs/aishell/ASR/transducer_stateless/joiner.py index 2ef3f1de6..9371aec5a 100644 --- a/egs/aishell/ASR/transducer_stateless/joiner.py +++ b/egs/aishell/ASR/transducer_stateless/joiner.py @@ -19,10 +19,12 @@ import torch.nn as nn class Joiner(nn.Module): - def __init__(self, input_dim: int, output_dim: int): + def __init__(self, input_dim: int, inner_dim: int, output_dim: int): super().__init__() - self.output_linear = nn.Linear(input_dim, output_dim) + self.output_linear = nn.Sequential( + nn.Linear(input_dim, inner_dim), nn.Linear(inner_dim, output_dim) + ) def forward( self, encoder_out: torch.Tensor, decoder_out: torch.Tensor @@ -36,15 +38,8 @@ class Joiner(nn.Module): Returns: Return a tensor of shape (N, T, U, C). """ - assert encoder_out.ndim == decoder_out.ndim == 3 - assert encoder_out.size(0) == decoder_out.size(0) - assert encoder_out.size(2) == decoder_out.size(2) - - encoder_out = encoder_out.unsqueeze(2) - # Now encoder_out is (N, T, 1, C) - - decoder_out = decoder_out.unsqueeze(1) - # Now decoder_out is (N, 1, U, C) + assert encoder_out.ndim == decoder_out.ndim == 4 + assert encoder_out.shape == decoder_out.shape logit = encoder_out + decoder_out logit = torch.tanh(logit) diff --git a/egs/aishell/ASR/transducer_stateless/model.py b/egs/aishell/ASR/transducer_stateless/model.py index 2f0f9a183..d69330368 100644 --- a/egs/aishell/ASR/transducer_stateless/model.py +++ b/egs/aishell/ASR/transducer_stateless/model.py @@ -1,4 +1,4 @@ -# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) +# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang, Wei Kang) # # See ../../../../LICENSE for clarification regarding multiple authors # @@ -14,15 +14,9 @@ # See the License for the specific language governing permissions and # limitations under the License. -""" -Note we use `rnnt_loss` from torchaudio, which exists only in -torchaudio >= v0.10.0. 
diff --git a/egs/aishell/ASR/transducer_stateless/model.py b/egs/aishell/ASR/transducer_stateless/model.py
index 2f0f9a183..d69330368 100644
--- a/egs/aishell/ASR/transducer_stateless/model.py
+++ b/egs/aishell/ASR/transducer_stateless/model.py
@@ -1,4 +1,4 @@
-# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
+# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang, Wei Kang)
 #
 # See ../../../../LICENSE for clarification regarding multiple authors
 #
@@ -14,15 +14,9 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-"""
-Note we use `rnnt_loss` from torchaudio, which exists only in
-torchaudio >= v0.10.0. It also means you have to use torch >= v1.10.0
-"""
 import k2
 import torch
 import torch.nn as nn
-import torchaudio
-import torchaudio.functional
 from encoder_interface import EncoderInterface

 from icefall.utils import add_sos
@@ -38,6 +32,9 @@ class Transducer(nn.Module):
         encoder: EncoderInterface,
         decoder: nn.Module,
         joiner: nn.Module,
+        prune_range: int = 5,
+        lm_scale: float = 0.0,
+        am_scale: float = 0.0,
     ):
         """
         Args:
@@ -62,6 +59,9 @@ class Transducer(nn.Module):
         self.encoder = encoder
         self.decoder = decoder
         self.joiner = joiner
+        self.prune_range = prune_range
+        self.lm_scale = lm_scale
+        self.am_scale = am_scale

     def forward(
         self,
@@ -102,24 +102,38 @@ class Transducer(nn.Module):

         decoder_out = self.decoder(sos_y_padded)

-        logits = self.joiner(encoder_out, decoder_out)
-
-        # rnnt_loss requires 0 padded targets
         # Note: y does not start with SOS
         y_padded = y.pad(mode="constant", padding_value=0)

-        assert hasattr(torchaudio.functional, "rnnt_loss"), (
-            f"Current torchaudio version: {torchaudio.__version__}\n"
-            "Please install a version >= 0.10.0"
+        y_padded = y_padded.to(torch.int64)
+        boundary = torch.zeros(
+            (x.size(0), 4), dtype=torch.int64, device=x.device
+        )
+        boundary[:, 2] = y_lens
+        boundary[:, 3] = x_lens
+
+        simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed(
+            decoder_out,
+            encoder_out,
+            y_padded,
+            blank_id,
+            lm_only_scale=self.lm_scale,
+            am_only_scale=self.am_scale,
+            boundary=boundary,
+            return_grad=True,
         )

-        loss = torchaudio.functional.rnnt_loss(
-            logits=logits,
-            targets=y_padded,
-            logit_lengths=x_lens,
-            target_lengths=y_lens,
-            blank=blank_id,
-            reduction="sum",
+        ranges = k2.get_rnnt_prune_ranges(
+            px_grad, py_grad, boundary, self.prune_range
+        )
+        am_pruned, lm_pruned = k2.do_rnnt_pruning(
+            encoder_out, decoder_out, ranges
         )

-        return loss
+        logits = self.joiner(am_pruned, lm_pruned)
+
+        pruned_loss = k2.rnnt_loss_pruned(
+            logits, y_padded, ranges, blank_id, boundary
+        )
+
+        return (-torch.sum(simple_loss), -torch.sum(pruned_loss))
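
The new forward pass is the two-stage pruned RNN-T recipe: a cheap "simple"
loss over the projected encoder/decoder outputs yields gradients that select,
per frame, a narrow band of symbol positions, and the joiner plus the exact
loss then run only inside that band. A condensed standalone sketch of the same
pipeline (shapes, vocab size, and the tanh stand-in for the Joiner are
illustrative; the k2 calls mirror the hunk above):

    import k2
    import torch

    N, T, U, C = 2, 50, 10, 500  # batch, frames, symbols, vocab (illustrative)
    am = torch.randn(N, T, C, requires_grad=True)      # encoder output, already vocab-sized
    lm = torch.randn(N, U + 1, C, requires_grad=True)  # decoder output (with SOS prepended)
    symbols = torch.randint(1, C, (N, U), dtype=torch.int64)

    boundary = torch.zeros((N, 4), dtype=torch.int64)
    boundary[:, 2] = U  # per-sequence symbol lengths
    boundary[:, 3] = T  # per-sequence frame lengths

    # Stage 1: smoothed "simple" loss; px_grad/py_grad mark the useful (t, u) cells.
    simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed(
        lm, am, symbols, 0,  # 0 is the blank id here
        lm_only_scale=0.0, am_only_scale=0.0,
        boundary=boundary, return_grad=True,
    )

    # Stage 2: keep prune_range (here 5) symbols per frame and gather the
    # pruned (N, T, prune_range, C) activations.
    ranges = k2.get_rnnt_prune_ranges(px_grad, py_grad, boundary, 5)
    am_pruned, lm_pruned = k2.do_rnnt_pruning(am, lm, ranges)

    # The joiner runs only on the pruned tensors; this tanh is a stand-in for
    # the Joiner above (it omits the output projections).
    logits = torch.tanh(am_pruned + lm_pruned)
    pruned_loss = k2.rnnt_loss_pruned(logits, symbols, ranges, 0, boundary)
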
decoder + "embedding_dim": 256, # parameters for Noam - "warm_step": 80000, # For the 100h subset, use 8k + "warm_step": 30000, "env_info": get_env_info(), } ) @@ -208,7 +230,7 @@ def get_encoder_model(params: AttributeDict): # TODO: We can add an option to switch between Conformer and Transformer encoder = Conformer( num_features=params.feature_dim, - output_dim=params.encoder_out_dim, + output_dim=params.vocab_size, subsampling_factor=params.subsampling_factor, d_model=params.attention_dim, nhead=params.nhead, @@ -222,7 +244,7 @@ def get_encoder_model(params: AttributeDict): def get_decoder_model(params: AttributeDict): decoder = Decoder( vocab_size=params.vocab_size, - embedding_dim=params.encoder_out_dim, + embedding_dim=params.embedding_dim, blank_id=params.blank_id, context_size=params.context_size, ) @@ -231,7 +253,8 @@ def get_decoder_model(params: AttributeDict): def get_joiner_model(params: AttributeDict): joiner = Joiner( - input_dim=params.encoder_out_dim, + input_dim=params.vocab_size, + inner_dim=params.embedding_dim, output_dim=params.vocab_size, ) return joiner @@ -246,6 +269,9 @@ def get_transducer_model(params: AttributeDict): encoder=encoder, decoder=decoder, joiner=joiner, + prune_range=params.prune_range, + lm_scale=params.lm_scale, + am_scale=params.am_scale, ) return model @@ -374,7 +400,8 @@ def compute_loss( y = k2.RaggedTensor(y).to(device) with torch.set_grad_enabled(is_training): - loss = model(x=feature, x_lens=feature_lens, y=y) + simple_loss, pruned_loss = model(x=feature, x_lens=feature_lens, y=y) + loss = simple_loss + pruned_loss assert loss.requires_grad == is_training @@ -383,6 +410,8 @@ def compute_loss( # Note: We use reduction=sum while computing the loss. info["loss"] = loss.detach().cpu().item() + info["simple_loss"] = simple_loss.detach().cpu().item() + info["pruned_loss"] = pruned_loss.detach().cpu().item() return loss, info @@ -476,7 +505,6 @@ def train_one_epoch( optimizer.zero_grad() loss.backward() - clip_grad_norm_(model.parameters(), 5.0, 2.0) optimizer.step() if batch_idx % params.log_interval == 0: @@ -555,10 +583,9 @@ def run(rank, world_size, args): graph_compiler = CharCtcTrainingGraphCompiler( lexicon=lexicon, device=device, - oov="", ) - params.blank_id = graph_compiler.texts_to_ids("")[0][0] + params.blank_id = 0 params.vocab_size = max(lexicon.tokens) + 1 logging.info(params)