From feb526c2a40e0919a2f4cf9cde9d35eb442696ca Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Thu, 31 Mar 2022 20:12:41 +0800 Subject: [PATCH] Predicting blanks via gradients from the trivial joiner. --- .../blank_predictor.py | 65 +++++++++++++++++++ .../pruned_transducer_stateless-2/model.py | 36 ++++++++-- .../test_blank_predictor.py | 43 ++++++++++++ .../pruned_transducer_stateless-2/train.py | 30 +++++++-- 4 files changed, 163 insertions(+), 11 deletions(-) create mode 100644 egs/librispeech/ASR/pruned_transducer_stateless-2/blank_predictor.py create mode 100755 egs/librispeech/ASR/pruned_transducer_stateless-2/test_blank_predictor.py diff --git a/egs/librispeech/ASR/pruned_transducer_stateless-2/blank_predictor.py b/egs/librispeech/ASR/pruned_transducer_stateless-2/blank_predictor.py new file mode 100644 index 000000000..03f82700b --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless-2/blank_predictor.py @@ -0,0 +1,65 @@ +# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang, Wei Kang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +import torch.nn as nn + +from icefall.utils import make_pad_mask + + +class BlankPredictor(nn.Module): + def __init__(self, encoder_out_dim: int): + """ + Args: + Output dimension of the encoder network. + """ + super().__init__() + self.linear = nn.Linear(in_features=encoder_out_dim, out_features=1) + + self.loss_func = nn.BCEWithLogitsLoss(reduction="none") + + def forward( + self, + x: torch.Tensor, + x_lens: torch.Tensor, + soft_target: torch.Tensor, + ) -> torch.Tensor: + """ + Args: + x: + A 3-D tensor of shape (N, T, encoder_out_dim) from the output of + the encoder network. + x_lens: + A 1-D tensor of shape (N,) containing the number of valid frames + for each element in `x`. + soft_target: + A 2-D tensor of shape (N, T) containing the soft label of each frame + in `x`. + """ + assert x.ndim == 3, x.shape + assert soft_target.ndim == 2, soft_target.shape + + assert x.shape[:2] == soft_target.shape[:2], ( + x.shape, + soft_target.shape, + ) + logits = self.linear(x).squeeze(-1) + mask = make_pad_mask(x_lens) + + loss = self.loss_func(logits, soft_target) + loss.masked_fill_(mask, 0) + + return loss.sum() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless-2/model.py b/egs/librispeech/ASR/pruned_transducer_stateless-2/model.py index 2f019bcdb..f0c08feb6 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless-2/model.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless-2/model.py @@ -15,6 +15,8 @@ # limitations under the License. +from typing import Tuple + import k2 import torch import torch.nn as nn @@ -33,6 +35,7 @@ class Transducer(nn.Module): encoder: EncoderInterface, decoder: nn.Module, joiner: nn.Module, + blank_predictor: nn.Module, ): """ Args: @@ -49,6 +52,9 @@ class Transducer(nn.Module): It has two inputs with shapes: (N, T, C) and (N, U, C). Its output shape is (N, T, U, C). Note that its output contains unnormalized probs, i.e., not processed by log-softmax. + blank_predictor: + The model to predict blanks from the encoder output. See also + `./blank_predictor.py`. """ super().__init__() assert isinstance(encoder, EncoderInterface), type(encoder) @@ -57,6 +63,7 @@ class Transducer(nn.Module): self.encoder = encoder self.decoder = decoder self.joiner = joiner + self.blank_predictor = blank_predictor def forward( self, @@ -66,7 +73,7 @@ class Transducer(nn.Module): prune_range: int = 5, am_scale: float = 0.0, lm_scale: float = 0.0, - ) -> torch.Tensor: + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """ Args: x: @@ -87,7 +94,11 @@ class Transducer(nn.Module): The scale to smooth the loss with lm (output of predictor network) part Returns: - Return the transducer loss. + Return a tuple containing: + + - The loss for the "trivial" joiner + - The loss for the non-linear joiner + - The loss for predicting the blank token Note: Regarding am_scale & lm_scale, it will make the loss-function one of @@ -101,8 +112,8 @@ class Transducer(nn.Module): assert x.size(0) == x_lens.size(0) == y.dim0 - encoder_out, x_lens = self.encoder(x, x_lens) - assert torch.all(x_lens > 0) + encoder_out, encoder_out_lens = self.encoder(x, x_lens) + assert torch.all(encoder_out_lens > 0) # Now for the decoder, i.e., the prediction network row_splits = y.shape.row_splits(1) @@ -126,7 +137,7 @@ class Transducer(nn.Module): (x.size(0), 4), dtype=torch.int64, device=x.device ) boundary[:, 2] = y_lens - boundary[:, 3] = x_lens + boundary[:, 3] = encoder_out_lens simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed( lm=decoder_out, @@ -139,6 +150,19 @@ class Transducer(nn.Module): reduction="sum", return_grad=True, ) + # + # px_grad shape: (B, y_lens.max(), T+1) + # Note: In the paper, we use y'(t, u) + # + non_blank_occuptation = px_grad[:, :, :-1].sum(dim=1) + non_blank_occuptation = torch.clamp(non_blank_occuptation, min=0, max=1) + blank_occupation = 1 - non_blank_occuptation + + blank_prediction_loss = self.blank_predictor( + encoder_out, + encoder_out_lens, + blank_occupation, + ) # ranges : [B, T, prune_range] ranges = k2.get_rnnt_prune_ranges( @@ -166,4 +190,4 @@ class Transducer(nn.Module): reduction="sum", ) - return (simple_loss, pruned_loss) + return (simple_loss, pruned_loss, blank_prediction_loss) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless-2/test_blank_predictor.py b/egs/librispeech/ASR/pruned_transducer_stateless-2/test_blank_predictor.py new file mode 100755 index 000000000..93393e3a0 --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless-2/test_blank_predictor.py @@ -0,0 +1,43 @@ +#!/usr/bin/env python3 +# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +To run this file, do: + + cd icefall/egs/librispeech/ASR + python ./pruned_transducer_stateless_2/test_blank_predictor.py +""" +import torch +from blank_predictor import BlankPredictor + + +def test_blank_predictor(): + dim = 10 + predictor = BlankPredictor(encoder_out_dim=dim) + x = torch.rand(4, 3, dim) + x_lens = torch.tensor([1, 3, 2, 3], dtype=torch.int32) + y = torch.rand(4, 3) + loss = predictor(x, x_lens, y) + print(loss) + + +def main(): + test_blank_predictor() + + +if __name__ == "__main__": + main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless-2/train.py b/egs/librispeech/ASR/pruned_transducer_stateless-2/train.py index 17f82e601..d46330bb8 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless-2/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless-2/train.py @@ -21,11 +21,11 @@ Usage: export CUDA_VISIBLE_DEVICES="0,1,2,3" -./pruned_transducer_stateless/train.py \ +./pruned_transducer_stateless-2/train.py \ --world-size 4 \ --num-epochs 30 \ --start-epoch 0 \ - --exp-dir pruned_transducer_stateless/exp \ + --exp-dir pruned_transducer_stateless-2/exp \ --full-libri 1 \ --max-duration 300 """ @@ -44,6 +44,7 @@ import torch import torch.multiprocessing as mp import torch.nn as nn from asr_datamodule import LibriSpeechAsrDataModule +from blank_predictor import BlankPredictor from conformer import Conformer from decoder import Decoder from joiner import Joiner @@ -128,7 +129,7 @@ def get_parser(): parser.add_argument( "--exp-dir", type=str, - default="pruned_transducer_stateless/exp", + default="pruned_transducer_stateless-2/exp", help="""The experiment dir. It specifies the directory where all training related files, e.g., checkpoints, log, etc, are saved @@ -191,6 +192,13 @@ def get_parser(): "with this parameter before adding to the final loss.", ) + parser.add_argument( + "--blank-prediction-scale", + type=float, + default=0.1, + help="Scale to use for the blank prediction loss", + ) + parser.add_argument( "--seed", type=int, @@ -333,15 +341,22 @@ def get_joiner_model(params: AttributeDict) -> nn.Module: return joiner +def get_blank_prediction_model(params: AttributeDict) -> nn.Module: + blank_predictor = BlankPredictor(encoder_out_dim=params.vocab_size) + return blank_predictor + + def get_transducer_model(params: AttributeDict) -> nn.Module: encoder = get_encoder_model(params) decoder = get_decoder_model(params) joiner = get_joiner_model(params) + blank_predictor = get_blank_prediction_model(params) model = Transducer( encoder=encoder, decoder=decoder, joiner=joiner, + blank_predictor=blank_predictor, ) return model @@ -484,7 +499,7 @@ def compute_loss( y = k2.RaggedTensor(y).to(device) with torch.set_grad_enabled(is_training): - simple_loss, pruned_loss = model( + simple_loss, pruned_loss, blank_prediction_loss = model( x=feature, x_lens=feature_lens, y=y, @@ -492,7 +507,11 @@ def compute_loss( am_scale=params.am_scale, lm_scale=params.lm_scale, ) - loss = params.simple_loss_scale * simple_loss + pruned_loss + loss = ( + params.simple_loss_scale * simple_loss + + pruned_loss + + params.blank_prediction_scale * blank_prediction_loss + ) assert loss.requires_grad == is_training @@ -507,6 +526,7 @@ def compute_loss( info["loss"] = loss.detach().cpu().item() info["simple_loss"] = simple_loss.detach().cpu().item() info["pruned_loss"] = pruned_loss.detach().cpu().item() + info["blank_prediction_loss"] = blank_prediction_loss.detach().cpu().item() return loss, info