From 8653b6a68a7f732a212fc5e65dfa0520a08cae73 Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Mon, 7 Feb 2022 12:09:26 +0800 Subject: [PATCH 1/4] Apply random frame shift along the time axis. --- .../ASR/tdnn_lstm_ctc/asr_datamodule.py | 28 ++++++- .../ASR/transducer_stateless/frame_shift.py | 84 +++++++++++++++++++ .../transducer_stateless/test_frame_shift.py | 70 ++++++++++++++++ .../ASR/transducer_stateless/train.py | 20 ++++- 4 files changed, 200 insertions(+), 2 deletions(-) create mode 100644 egs/librispeech/ASR/transducer_stateless/frame_shift.py create mode 100755 egs/librispeech/ASR/transducer_stateless/test_frame_shift.py diff --git a/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py b/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py index e075a2d03..ba9e08569 100644 --- a/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py +++ b/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py @@ -19,7 +19,9 @@ import argparse import logging from functools import lru_cache from pathlib import Path +from typing import Callable, List, Optional +import torch from lhotse import CutSet, Fbank, FbankConfig, load_manifest from lhotse.dataset import ( BucketingSampler, @@ -179,7 +181,27 @@ class LibriSpeechAsrDataModule: "with training dataset. ", ) - def train_dataloaders(self, cuts_train: CutSet) -> DataLoader: + def train_dataloaders( + self, + cuts_train: CutSet, + extra_input_transforms: Optional[ + List[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] + ], + ) -> DataLoader: + """ + Args: + cuts_train: + The cutset for training. + extra_input_transforms: + The extra input transforms that will be applied after all input + transforms, e.g., after SpecAugment if there is any. + Each input transform accepts two input arguments: + - A 3-D torch.Tensor of shape (N, T, C) + - A 2-D torch.Tensor of shape (num_seqs, 3), where the + first column is `sequence_idx`, the second column is + `start_frame`, and the third column is `num_frames`. + and returns a 3-D torch.Tensor of shape (N, T, C). + """ logging.info("About to get Musan cuts") cuts_musan = load_manifest( self.args.manifest_dir / "cuts_musan.json.gz" @@ -228,6 +250,10 @@ class LibriSpeechAsrDataModule: else: logging.info("Disable SpecAugment") + if extra_input_transforms is not None: + input_transforms += extra_input_transforms + logging.info(f"Input transforms: {input_transforms}") + logging.info("About to create train dataset") train = K2SpeechRecognitionDataset( cut_transforms=transforms, diff --git a/egs/librispeech/ASR/transducer_stateless/frame_shift.py b/egs/librispeech/ASR/transducer_stateless/frame_shift.py new file mode 100644 index 000000000..f574bd74b --- /dev/null +++ b/egs/librispeech/ASR/transducer_stateless/frame_shift.py @@ -0,0 +1,84 @@ +# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import torch +from lhotse.utils import LOG_EPSILON + + +def apply_frame_shift( + features: torch.Tensor, + supervision_segments: torch.Tensor, +) -> torch.Tensor: + """Apply random frame shift along the time axis. + + For instance, for the input frame `[a, b, c, d]`, + + - If frame shift is 0, the resulting output is `[a, b, c, d]` + - If frame shift is -1, the resulting output is `[b, c, d, a]` + - If frame shift is 1, the resulting output is `[d, a, b, c]` + - If frame shift is 2, the resulting output is `[c, d, a, b]` + + Args: + features: + A 3-D tensor of shape (N, T, C). + supervision_segments: + A 2-D tensor of shape (num_seqs, 3). The first column is + `sequence_idx`, the second column is `start_frame`, and + the third column is `num_frames`. + Returns: + Return a 3-D tensor of shape (N, T, C). + """ + # We assume the subsampling_factor is 4. If you change the + # subsampling_factor, you should also change the following + # list accordingly + # + # The value in frame_shifts is selected in such a way that + # "value % subsampling_factor" is not duplicated in frame_shifts. + frame_shifts = [-1, 0, 1, 2] + + N = features.size(0) + + # We don't support cut concatenation here + assert torch.all( + torch.eq(supervision_segments[:, 0], torch.arange(N)) + ), supervision_segments + + ans = [] + for i in range(N): + start = supervision_segments[i, 1] + end = start + supervision_segments[i, 2] + + feat = features[i, start:end, :] + + r = torch.randint(low=0, high=len(frame_shifts), size=(1,)).item() + frame_shift = frame_shifts[r] + + # You can enable the following debug statement + # and run ./transducer_stateless/test_frame_shift.py to + # view the debug output. + # print("frame_shift", frame_shift) + + feat = torch.roll(feat, shifts=frame_shift, dims=0) + ans.append(feat) + + ans = torch.nn.utils.rnn.pad_sequence( + ans, + batch_first=True, + padding_value=LOG_EPSILON, + ) + assert features.shape == ans.shape + + return ans diff --git a/egs/librispeech/ASR/transducer_stateless/test_frame_shift.py b/egs/librispeech/ASR/transducer_stateless/test_frame_shift.py new file mode 100755 index 000000000..ca1054a63 --- /dev/null +++ b/egs/librispeech/ASR/transducer_stateless/test_frame_shift.py @@ -0,0 +1,70 @@ +#!/usr/bin/env python3 +# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +""" +To run this file, do: + + cd icefall/egs/librispeech/ASR + python ./transducer_stateless/test_frame_shift.py +""" + +import torch +from frame_shift import apply_frame_shift + + +def test_apply_frame_shift(): + features = torch.tensor( + [ + [ + [1, 2, 5], + [2, 6, 9], + [3, 0, 2], + [4, 11, 13], + [0, 0, 0], + [0, 0, 0], + ], + [ + [1, 3, 9], + [2, 5, 8], + [3, 3, 6], + [4, 0, 3], + [5, 1, 2], + [6, 6, 6], + ], + ] + ) + supervision_segments = torch.tensor( + [ + [0, 0, 4], + [1, 0, 6], + ], + dtype=torch.int32, + ) + shifted_features = apply_frame_shift(features, supervision_segments) + + # You can enable the debug statement in frame_shift.py + # and check the resulting shifted_features. I've verified + # manually that it is correct. + print(shifted_features) + + +def main(): + test_apply_frame_shift() + + +if __name__ == "__main__": + main() diff --git a/egs/librispeech/ASR/transducer_stateless/train.py b/egs/librispeech/ASR/transducer_stateless/train.py index 950a88a35..0915bda0f 100755 --- a/egs/librispeech/ASR/transducer_stateless/train.py +++ b/egs/librispeech/ASR/transducer_stateless/train.py @@ -46,6 +46,7 @@ import torch.nn as nn from asr_datamodule import LibriSpeechAsrDataModule from conformer import Conformer from decoder import Decoder +from frame_shift import apply_frame_shift from joiner import Joiner from lhotse.cut import Cut from lhotse.utils import fix_random_seed @@ -138,6 +139,13 @@ def get_parser(): "2 means tri-gram", ) + parser.add_argument( + "--apply-frame-shift", + type=str2bool, + default=False, + help="If enabled, apply random frame shift along the time axis", + ) + return parser @@ -620,7 +628,17 @@ def run(rank, world_size, args): logging.info(f"After removing short and long utterances: {num_left}") logging.info(f"Removed {num_removed} utterances ({removed_percent:.5f}%)") - train_dl = librispeech.train_dataloaders(train_cuts) + if params.apply_frame_shift: + logging.info("Enable random frame shift") + extra_input_transforms = [apply_frame_shift] + else: + logging.info("Disable random frame shift") + extra_input_transforms = None + + train_dl = librispeech.train_dataloaders( + train_cuts, + extra_input_transforms=extra_input_transforms, + ) valid_cuts = librispeech.dev_clean_cuts() valid_cuts += librispeech.dev_other_cuts() From f2a45eb38d1605c3400c9e9ef9b9eddd1126163f Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Mon, 7 Feb 2022 19:02:47 +0800 Subject: [PATCH 2/4] Remove learnable offset, use relu instead. 
See https://github.com/k2-fsa/icefall/pull/199 --- .../ASR/transducer_stateless/conformer.py | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/egs/librispeech/ASR/transducer_stateless/conformer.py b/egs/librispeech/ASR/transducer_stateless/conformer.py index 81d7708f9..d1a28ccd9 100644 --- a/egs/librispeech/ASR/transducer_stateless/conformer.py +++ b/egs/librispeech/ASR/transducer_stateless/conformer.py @@ -629,9 +629,11 @@ class RelPositionMultiheadAttention(nn.Module): if torch.equal(query, key) and torch.equal(key, value): # self-attention - q, k, v = nn.functional.linear( - query, in_proj_weight, in_proj_bias - ).chunk(3, dim=-1) + q, k, v = ( + nn.functional.linear(query, in_proj_weight, in_proj_bias) + .relu() + .chunk(3, dim=-1) + ) elif torch.equal(key, value): # encoder-decoder attention @@ -642,7 +644,7 @@ class RelPositionMultiheadAttention(nn.Module): _w = in_proj_weight[_start:_end, :] if _b is not None: _b = _b[_start:_end] - q = nn.functional.linear(query, _w, _b) + q = nn.functional.linear(query, _w, _b).relu() # This is inline in_proj function with in_proj_weight and in_proj_bias _b = in_proj_bias _start = embed_dim @@ -650,7 +652,7 @@ class RelPositionMultiheadAttention(nn.Module): _w = in_proj_weight[_start:, :] if _b is not None: _b = _b[_start:] - k, v = nn.functional.linear(key, _w, _b).chunk(2, dim=-1) + k, v = nn.functional.linear(key, _w, _b).relu().chunk(2, dim=-1) else: # This is inline in_proj function with in_proj_weight and in_proj_bias @@ -660,7 +662,7 @@ class RelPositionMultiheadAttention(nn.Module): _w = in_proj_weight[_start:_end, :] if _b is not None: _b = _b[_start:_end] - q = nn.functional.linear(query, _w, _b) + q = nn.functional.linear(query, _w, _b).relu() # This is inline in_proj function with in_proj_weight and in_proj_bias _b = in_proj_bias @@ -669,7 +671,7 @@ class RelPositionMultiheadAttention(nn.Module): _w = in_proj_weight[_start:_end, :] if _b is not None: _b = _b[_start:_end] - k = nn.functional.linear(key, _w, _b) + k = nn.functional.linear(key, _w, _b).relu() # This is inline in_proj function with in_proj_weight and in_proj_bias _b = in_proj_bias @@ -678,7 +680,7 @@ class RelPositionMultiheadAttention(nn.Module): _w = in_proj_weight[_start:, :] if _b is not None: _b = _b[_start:] - v = nn.functional.linear(value, _w, _b) + v = nn.functional.linear(value, _w, _b).relu() if attn_mask is not None: assert ( From 09bbed327572a05a5181792f1d4d72994e4c83ba Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Mon, 7 Feb 2022 19:10:21 +0800 Subject: [PATCH 3/4] Use CTC loss as auxiliary loss. See https://github.com/k2-fsa/icefall/pull/186 --- .../ASR/transducer_stateless/model.py | 11 ++- .../ASR/transducer_stateless/train.py | 79 ++++++++++++++++++- 2 files changed, 85 insertions(+), 5 deletions(-) diff --git a/egs/librispeech/ASR/transducer_stateless/model.py b/egs/librispeech/ASR/transducer_stateless/model.py index 8281e1fb5..4e4f9d13d 100644 --- a/egs/librispeech/ASR/transducer_stateless/model.py +++ b/egs/librispeech/ASR/transducer_stateless/model.py @@ -79,7 +79,10 @@ class Transducer(nn.Module): modified_transducer_prob: The probability to use modified transducer loss. Returns: - Return the transducer loss. 
+          Return a tuple containing:
+            - the transducer loss, a tensor containing only one entry
+            - encoder_out, a tensor of shape (N, num_frames, encoder_out_dim)
+            - encoder_out_lens, a tensor of shape (N,)
         """
         assert x.ndim == 3, x.shape
         assert x_lens.ndim == 1, x_lens.shape
@@ -140,4 +143,8 @@ class Transducer(nn.Module):
             from_log_softmax=False,
         )
 
-        return loss
+        return (
+            loss,
+            encoder_out,
+            x_lens,
+        )
diff --git a/egs/librispeech/ASR/transducer_stateless/train.py b/egs/librispeech/ASR/transducer_stateless/train.py
index c0b1b3a42..0ceb523e7 100755
--- a/egs/librispeech/ASR/transducer_stateless/train.py
+++ b/egs/librispeech/ASR/transducer_stateless/train.py
@@ -157,6 +157,14 @@ def get_parser():
         help="If enabled, apply random frame shift along the time axis",
     )
 
+    parser.add_argument(
+        "--ctc-weight",
+        type=float,
+        default=0.25,
+        help="""If not zero, the total loss is:
+          (1 - ctc_weight) * transducer_loss + ctc_weight * ctc_loss
+        """,
+    )
     return parser
 
 
@@ -225,6 +233,13 @@ def get_params() -> AttributeDict:
             "vgg_frontend": False,
             # parameters for Noam
             "warm_step": 80000,  # For the 100h subset, use 8k
+            #
+            # parameters for ctc_loss, used only when ctc_weight > 0
+            "modified_ctc_topo": False,
+            "beam_size": 10,
+            "reduction": "sum",
+            "use_double_scores": True,
+            #
             "env_info": get_env_info(),
         }
     )
@@ -278,6 +293,17 @@ def get_transducer_model(params: AttributeDict) -> nn.Module:
     return model
 
 
+def get_ctc_model(params: AttributeDict):
+    if params.ctc_weight > 0:
+        return nn.Sequential(
+            nn.Dropout(p=0.1),
+            nn.Linear(params.encoder_out_dim, params.vocab_size),
+            nn.LogSoftmax(dim=-1),
+        )
+    else:
+        return None
+
+
 def load_checkpoint_if_available(
     params: AttributeDict,
     model: nn.Module,
@@ -398,16 +424,55 @@ def compute_loss(
     feature_lens = supervisions["num_frames"].to(device)
 
     texts = batch["supervisions"]["text"]
-    y = sp.encode(texts, out_type=int)
-    y = k2.RaggedTensor(y).to(device)
+    token_ids = sp.encode(texts, out_type=int)
+    y = k2.RaggedTensor(token_ids).to(device)
 
     with torch.set_grad_enabled(is_training):
-        loss = model(
+        transducer_loss, encoder_out, encoder_out_lens = model(
             x=feature,
             x_lens=feature_lens,
             y=y,
             modified_transducer_prob=params.modified_transducer_prob,
         )
+        loss = transducer_loss
+
+        if params.ctc_weight > 0:
+            ctc_model = (
+                model.module.ctc if hasattr(model, "module") else model.ctc
+            )
+            ctc_graph = k2.ctc_graph(
+                token_ids, modified=params.modified_ctc_topo, device=device
+            )
+            # Note: We assume `encoder_out_lens` is sorted in descending order.
+            # If not, k2.ctc_loss() will raise an error.
+            supervision_segments = torch.stack(
+                [
+                    torch.arange(encoder_out.size(0), dtype=torch.int32),
+                    torch.zeros(encoder_out.size(0), dtype=torch.int32),
+                    encoder_out_lens.cpu(),
+                ],
+                dim=1,
+            ).to(torch.int32)
+            nnet_out = ctc_model(encoder_out)
+
+            dense_fsa_vec = k2.DenseFsaVec(
+                nnet_out,
+                supervision_segments,
+                allow_truncate=0,
+            )
+
+            # Note: transducer_loss should use the same reduction as ctc_loss
+            ctc_loss = k2.ctc_loss(
+                decoding_graph=ctc_graph,
+                dense_fsa_vec=dense_fsa_vec,
+                output_beam=params.beam_size,
+                reduction=params.reduction,
+                use_double_scores=params.use_double_scores,
+            )
+            assert ctc_loss.requires_grad == is_training
+            loss = (
+                1 - params.ctc_weight
+            ) * transducer_loss + params.ctc_weight * ctc_loss
 
     assert loss.requires_grad == is_training
 
@@ -416,6 +481,9 @@ def compute_loss(
     # Note: We use reduction=sum while computing the loss.
info["loss"] = loss.detach().cpu().item() + info["transducer_loss"] = transducer_loss.detach().cpu().item() + if params.ctc_weight > 0: + info["ctc_loss"] = ctc_loss.detach().cpu().item() return loss, info @@ -598,6 +666,11 @@ def run(rank, world_size, args): logging.info("About to create model") model = get_transducer_model(params) + model.ctc = get_ctc_model(params) + if model.ctc is not None: + logging.info(f"Enable CTC loss with weight: {params.ctc_weight}") + else: + logging.info("Disable CTC loss") num_param = sum([p.numel() for p in model.parameters()]) logging.info(f"Number of model parameters: {num_param}") From 136c03d040ad8d81ae8d0ccf5a3a5d9d11d9c79c Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Wed, 9 Feb 2022 12:15:12 +0800 Subject: [PATCH 4/4] Fix decoding. --- egs/librispeech/ASR/transducer_stateless/decode.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/transducer_stateless/decode.py b/egs/librispeech/ASR/transducer_stateless/decode.py index c101d9397..0c07ba308 100755 --- a/egs/librispeech/ASR/transducer_stateless/decode.py +++ b/egs/librispeech/ASR/transducer_stateless/decode.py @@ -441,7 +441,9 @@ def main(): filenames.append(f"{params.exp_dir}/epoch-{i}.pt") logging.info(f"averaging {filenames}") model.to(device) - model.load_state_dict(average_checkpoints(filenames, device=device)) + model.load_state_dict( + average_checkpoints(filenames, device=device), strict=False + ) model.to(device) model.eval()