From 026f446a4d830f1a6f1b9441bed327f4e0bf6c9d Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Thu, 28 Apr 2022 14:13:26 +0800 Subject: [PATCH] Use k2 pruned RNN-T. --- .../ASR/transducer_lstm/beam_search.py | 2 +- egs/librispeech/ASR/transducer_lstm/decode.py | 178 ++++++- .../ASR/transducer_lstm/decoder.py | 99 +--- egs/librispeech/ASR/transducer_lstm/joiner.py | 58 +-- egs/librispeech/ASR/transducer_lstm/model.py | 127 +---- egs/librispeech/ASR/transducer_lstm/optim.py | 1 + .../ASR/transducer_lstm/scaling.py | 1 + .../{test_encoder.py => test_model.py} | 32 +- egs/librispeech/ASR/transducer_lstm/train.py | 486 +++++++++++++----- 9 files changed, 529 insertions(+), 455 deletions(-) mode change 100644 => 120000 egs/librispeech/ASR/transducer_lstm/decoder.py mode change 100644 => 120000 egs/librispeech/ASR/transducer_lstm/joiner.py mode change 100644 => 120000 egs/librispeech/ASR/transducer_lstm/model.py create mode 120000 egs/librispeech/ASR/transducer_lstm/optim.py create mode 120000 egs/librispeech/ASR/transducer_lstm/scaling.py rename egs/librispeech/ASR/transducer_lstm/{test_encoder.py => test_model.py} (60%) diff --git a/egs/librispeech/ASR/transducer_lstm/beam_search.py b/egs/librispeech/ASR/transducer_lstm/beam_search.py index 08cb32ef7..8554e44cc 120000 --- a/egs/librispeech/ASR/transducer_lstm/beam_search.py +++ b/egs/librispeech/ASR/transducer_lstm/beam_search.py @@ -1 +1 @@ -../transducer_stateless/beam_search.py \ No newline at end of file +../pruned_transducer_stateless2/beam_search.py \ No newline at end of file diff --git a/egs/librispeech/ASR/transducer_lstm/decode.py b/egs/librispeech/ASR/transducer_lstm/decode.py index 8db103672..30d5d15a4 100755 --- a/egs/librispeech/ASR/transducer_lstm/decode.py +++ b/egs/librispeech/ASR/transducer_lstm/decode.py @@ -19,16 +19,16 @@ Usage: (1) greedy search ./transducer_lstm/decode.py \ - --epoch 14 \ - --avg 7 \ + --epoch 28 \ + --avg 15 \ --exp-dir ./transducer_lstm/exp \ --max-duration 100 \ --decoding-method greedy_search (2) beam search ./transducer_lstm/decode.py \ - --epoch 14 \ - --avg 7 \ + --epoch 28 \ + --avg 15 \ --exp-dir ./transducer_lstm/exp \ --max-duration 100 \ --decoding-method beam_search \ @@ -36,12 +36,23 @@ Usage: (3) modified beam search ./transducer_lstm/decode.py \ - --epoch 14 \ - --avg 7 \ + --epoch 28 \ + --avg 15 \ --exp-dir ./transducer_lstm/exp \ --max-duration 100 \ --decoding-method modified_beam_search \ --beam-size 4 + +(4) fast beam search +./transducer_lstm/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./transducer_lstm/exp \ + --max-duration 1500 \ + --decoding-method fast_beam_search \ + --beam 4 \ + --max-contexts 4 \ + --max-states 8 """ @@ -49,21 +60,27 @@ import argparse import logging from collections import defaultdict from pathlib import Path -from typing import Dict, List, Tuple +from typing import Dict, List, Optional, Tuple +import k2 import sentencepiece as spm import torch import torch.nn as nn from asr_datamodule import LibriSpeechAsrDataModule from beam_search import ( beam_search, + fast_beam_search, greedy_search, greedy_search_batch, modified_beam_search, ) from train import get_params, get_transducer_model -from icefall.checkpoint import average_checkpoints, load_checkpoint +from icefall.checkpoint import ( + average_checkpoints, + find_checkpoints, + load_checkpoint, +) from icefall.utils import ( AttributeDict, setup_logger, @@ -80,17 +97,29 @@ def get_parser(): parser.add_argument( "--epoch", type=int, - default=29, - help="It specifies the checkpoint to use for decoding." - "Note: Epoch counts from 0.", + default=28, + help="""It specifies the checkpoint to use for decoding. + Note: Epoch counts from 0. + You can specify --avg to use more checkpoints for model averaging.""", ) + + parser.add_argument( + "--iter", + type=int, + default=0, + help="""If positive, --epoch is ignored and it + will use the checkpoint exp_dir/checkpoint-iter.pt. + You can specify --avg to use more checkpoints for model averaging. + """, + ) + parser.add_argument( "--avg", type=int, - default=13, + default=15, help="Number of checkpoints to average. Automatically select " "consecutive checkpoints before the checkpoint specified by " - "'--epoch'. ", + "'--epoch' and '--iter'", ) parser.add_argument( @@ -115,6 +144,7 @@ def get_parser(): - greedy_search - beam_search - modified_beam_search + - fast_beam_search """, ) @@ -122,8 +152,35 @@ def get_parser(): "--beam-size", type=int, default=4, + help="""An integer indicating how many candidates we will keep for each + frame. Used only when --decoding-method is beam_search or + modified_beam_search.""", + ) + + parser.add_argument( + "--beam", + type=float, + default=4, + help="""A floating point value to calculate the cutoff score during beam + search (i.e., `cutoff = max-score - beam`), which is the same as the + `beam` in Kaldi. + Used only when --decoding-method is fast_beam_search""", + ) + + parser.add_argument( + "--max-contexts", + type=int, + default=4, help="""Used only when --decoding-method is - beam_search or modified_beam_search""", + fast_beam_search""", + ) + + parser.add_argument( + "--max-states", + type=int, + default=8, + help="""Used only when --decoding-method is + fast_beam_search""", ) parser.add_argument( @@ -149,6 +206,7 @@ def decode_one_batch( model: nn.Module, sp: spm.SentencePieceProcessor, batch: dict, + decoding_graph: Optional[k2.Fsa] = None, ) -> Dict[str, List[List[str]]]: """Decode one batch and return the result in a dict. The dict has the following format: @@ -171,6 +229,9 @@ def decode_one_batch( It is the return value from iterating `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation for the format of the `batch`. + decoding_graph: + The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used + only when --decoding_method is fast_beam_search. Returns: Return the decoding result. See above description for the format of the returned dict. @@ -188,24 +249,41 @@ def decode_one_batch( encoder_out, encoder_out_lens = model.encoder( x=feature, x_lens=feature_lens ) - hyp_list: List[List[int]] = [] + hyps = [] - if ( + if params.decoding_method == "fast_beam_search": + hyp_tokens = fast_beam_search( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif ( params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1 ): - hyp_list = greedy_search_batch( + hyp_tokens = greedy_search_batch( model=model, encoder_out=encoder_out, ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) elif params.decoding_method == "modified_beam_search": - hyp_list = modified_beam_search( + hyp_tokens = modified_beam_search( model=model, encoder_out=encoder_out, beam=params.beam_size, ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) else: batch_size = encoder_out.size(0) + for i in range(batch_size): # fmt: off encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]] @@ -226,14 +304,20 @@ def decode_one_batch( raise ValueError( f"Unsupported decoding method: {params.decoding_method}" ) - hyp_list.append(hyp) - - hyps = [sp.decode(hyp).split() for hyp in hyp_list] + hyps.append(sp.decode(hyp).split()) if params.decoding_method == "greedy_search": return {"greedy_search": hyps} + elif params.decoding_method == "fast_beam_search": + return { + ( + f"beam_{params.beam}_" + f"max_contexts_{params.max_contexts}_" + f"max_states_{params.max_states}" + ): hyps + } else: - return {f"beam_{params.beam_size}": hyps} + return {f"beam_size_{params.beam_size}": hyps} def decode_dataset( @@ -241,6 +325,7 @@ def decode_dataset( params: AttributeDict, model: nn.Module, sp: spm.SentencePieceProcessor, + decoding_graph: Optional[k2.Fsa] = None, ) -> Dict[str, List[Tuple[List[str], List[str]]]]: """Decode dataset. @@ -253,6 +338,9 @@ def decode_dataset( The neural model. sp: The BPE model. + decoding_graph: + The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used + only when --decoding_method is fast_beam_search. Returns: Return a dict, whose key may be "greedy_search" if greedy search is used, or it may be "beam_7" if beam size of 7 is used. @@ -280,6 +368,7 @@ def decode_dataset( params=params, model=model, sp=sp, + decoding_graph=decoding_graph, batch=batch, ) @@ -360,13 +449,24 @@ def main(): assert params.decoding_method in ( "greedy_search", "beam_search", + "fast_beam_search", "modified_beam_search", ) params.res_dir = params.exp_dir / params.decoding_method - params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" - if "beam_search" in params.decoding_method: - params.suffix += f"-beam-{params.beam_size}" + if params.iter > 0: + params.suffix = f"iter-{params.iter}-avg-{params.avg}" + else: + params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" + + if "fast_beam_search" in params.decoding_method: + params.suffix += f"-beam-{params.beam}" + params.suffix += f"-max-contexts-{params.max_contexts}" + params.suffix += f"-max-states-{params.max_states}" + elif "beam_search" in params.decoding_method: + params.suffix += ( + f"-{params.decoding_method}-beam-size-{params.beam_size}" + ) else: params.suffix += f"-context-{params.context_size}" params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" @@ -383,8 +483,9 @@ def main(): sp = spm.SentencePieceProcessor() sp.load(params.bpe_model) - # is defined in local/train_bpe_model.py + # and is defined in local/train_bpe_model.py params.blank_id = sp.piece_to_id("") + params.unk_id = sp.piece_to_id("") params.vocab_size = sp.get_piece_size() logging.info(params) @@ -392,7 +493,24 @@ def main(): logging.info("About to create model") model = get_transducer_model(params) - if params.avg == 1: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + elif params.avg == 1: load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) else: start = params.epoch - params.avg + 1 @@ -408,6 +526,11 @@ def main(): model.eval() model.device = device + if params.decoding_method == "fast_beam_search": + decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) + else: + decoding_graph = None + num_param = sum([p.numel() for p in model.parameters()]) logging.info(f"Number of model parameters: {num_param}") @@ -428,6 +551,7 @@ def main(): params=params, model=model, sp=sp, + decoding_graph=decoding_graph, ) save_results( diff --git a/egs/librispeech/ASR/transducer_lstm/decoder.py b/egs/librispeech/ASR/transducer_lstm/decoder.py deleted file mode 100644 index b82fed37b..000000000 --- a/egs/librispeech/ASR/transducer_lstm/decoder.py +++ /dev/null @@ -1,98 +0,0 @@ -# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import torch -import torch.nn as nn -import torch.nn.functional as F - - -class Decoder(nn.Module): - """This class modifies the stateless decoder from the following paper: - - RNN-transducer with stateless prediction network - https://ieeexplore.ieee.org/stamp/stamp.jsp?arnumber=9054419 - - It removes the recurrent connection from the decoder, i.e., the prediction - network. Different from the above paper, it adds an extra Conv1d - right after the embedding layer. - - TODO: Implement https://arxiv.org/pdf/2109.07513.pdf - """ - - def __init__( - self, - vocab_size: int, - embedding_dim: int, - blank_id: int, - context_size: int, - ): - """ - Args: - vocab_size: - Number of tokens of the modeling unit including blank. - embedding_dim: - Dimension of the input embedding. - blank_id: - The ID of the blank symbol. - context_size: - Number of previous words to use to predict the next word. - 1 means bigram; 2 means trigram. n means (n+1)-gram. - """ - super().__init__() - self.embedding = nn.Embedding( - num_embeddings=vocab_size, - embedding_dim=embedding_dim, - padding_idx=blank_id, - ) - self.blank_id = blank_id - - assert context_size >= 1, context_size - self.context_size = context_size - if context_size > 1: - self.conv = nn.Conv1d( - in_channels=embedding_dim, - out_channels=embedding_dim, - kernel_size=context_size, - padding=0, - groups=embedding_dim, - bias=False, - ) - - def forward(self, y: torch.Tensor, need_pad: bool = True) -> torch.Tensor: - """ - Args: - y: - A 2-D tensor of shape (N, U). - need_pad: - True to left pad the input. Should be True during training. - False to not pad the input. Should be False during inference. - Returns: - Return a tensor of shape (N, U, embedding_dim). - """ - embedding_out = self.embedding(y) - if self.context_size > 1: - embedding_out = embedding_out.permute(0, 2, 1) - if need_pad is True: - embedding_out = F.pad( - embedding_out, pad=(self.context_size - 1, 0) - ) - else: - # During inference time, there is no need to do extra padding - # as we only need one output - assert embedding_out.size(-1) == self.context_size - embedding_out = self.conv(embedding_out) - embedding_out = embedding_out.permute(0, 2, 1) - return embedding_out diff --git a/egs/librispeech/ASR/transducer_lstm/decoder.py b/egs/librispeech/ASR/transducer_lstm/decoder.py new file mode 120000 index 000000000..0793c5709 --- /dev/null +++ b/egs/librispeech/ASR/transducer_lstm/decoder.py @@ -0,0 +1 @@ +../pruned_transducer_stateless2/decoder.py \ No newline at end of file diff --git a/egs/librispeech/ASR/transducer_lstm/joiner.py b/egs/librispeech/ASR/transducer_lstm/joiner.py deleted file mode 100644 index 8c3710011..000000000 --- a/egs/librispeech/ASR/transducer_lstm/joiner.py +++ /dev/null @@ -1,57 +0,0 @@ -# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import torch -import torch.nn as nn -import torch.nn.functional as F - - -class Joiner(nn.Module): - def __init__(self, input_dim: int, output_dim: int): - super().__init__() - - self.output_linear = nn.Linear(input_dim, output_dim) - - def forward( - self, encoder_out: torch.Tensor, decoder_out: torch.Tensor, *unused - ) -> torch.Tensor: - """ - Args: - encoder_out: - Output from the encoder. Its shape is (N, T, C). - decoder_out: - Output from the decoder. Its shape is (N, U, C). - Returns: - Return a tensor of shape (N, T, U, C). - """ - assert encoder_out.ndim == decoder_out.ndim == 3 - assert encoder_out.size(0) == decoder_out.size(0) - assert encoder_out.size(2) == decoder_out.size(2) - - encoder_out = encoder_out.unsqueeze(2) - # Now encoder_out is (N, T, 1, C) - - decoder_out = decoder_out.unsqueeze(1) - # Now decoder_out is (N, 1, U, C) - - logit = encoder_out + decoder_out - logit = F.relu(logit) - - output = self.output_linear(logit) - if not self.training: - output = output.squeeze(2).squeeze(1) - - return output diff --git a/egs/librispeech/ASR/transducer_lstm/joiner.py b/egs/librispeech/ASR/transducer_lstm/joiner.py new file mode 120000 index 000000000..815fd4bb6 --- /dev/null +++ b/egs/librispeech/ASR/transducer_lstm/joiner.py @@ -0,0 +1 @@ +../pruned_transducer_stateless2/joiner.py \ No newline at end of file diff --git a/egs/librispeech/ASR/transducer_lstm/model.py b/egs/librispeech/ASR/transducer_lstm/model.py deleted file mode 100644 index 02c5eabb0..000000000 --- a/egs/librispeech/ASR/transducer_lstm/model.py +++ /dev/null @@ -1,126 +0,0 @@ -# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -""" -Note we use `rnnt_loss` from torchaudio, which exists only in -torchaudio >= v0.10.0. It also means you have to use torch >= v1.10.0 -""" -import k2 -import torch -import torch.nn as nn -import torchaudio -import torchaudio.functional -from encoder_interface import EncoderInterface - -from icefall.utils import add_sos - - -class Transducer(nn.Module): - """It implements https://arxiv.org/pdf/1211.3711.pdf - "Sequence Transduction with Recurrent Neural Networks" - """ - - def __init__( - self, - encoder: EncoderInterface, - decoder: nn.Module, - joiner: nn.Module, - ): - """ - Args: - encoder: - It is the transcription network in the paper. Its accepts - two inputs: `x` of (N, T, C) and `x_lens` of shape (N,). - It returns two tensors: `logits` of shape (N, T, C) and - `logit_lens` of shape (N,). - decoder: - It is the prediction network in the paper. Its input shape - is (N, U) and its output shape is (N, U, C). It should contain - one attribute: `blank_id`. - joiner: - It has two inputs with shapes: (N, T, C) and (N, U, C). Its - output shape is (N, T, U, C). Note that its output contains - unnormalized probs, i.e., not processed by log-softmax. - """ - super().__init__() - assert isinstance(encoder, EncoderInterface) - assert hasattr(decoder, "blank_id") - - self.encoder = encoder - self.decoder = decoder - self.joiner = joiner - - def forward( - self, - x: torch.Tensor, - x_lens: torch.Tensor, - y: k2.RaggedTensor, - ) -> torch.Tensor: - """ - Args: - x: - A 3-D tensor of shape (N, T, C). - x_lens: - A 1-D tensor of shape (N,). It contains the number of frames in `x` - before padding. - y: - A ragged tensor with 2 axes [utt][label]. It contains labels of each - utterance. - Returns: - Return the transducer loss. - """ - assert x.ndim == 3, x.shape - assert x_lens.ndim == 1, x_lens.shape - assert y.num_axes == 2, y.num_axes - - assert x.size(0) == x_lens.size(0) == y.dim0 - - encoder_out, x_lens = self.encoder(x, x_lens) - assert torch.all(x_lens > 0) - - # Now for the decoder, i.e., the prediction network - row_splits = y.shape.row_splits(1) - y_lens = row_splits[1:] - row_splits[:-1] - - blank_id = self.decoder.blank_id - sos_y = add_sos(y, sos_id=blank_id) - - sos_y_padded = sos_y.pad(mode="constant", padding_value=blank_id) - sos_y_padded = sos_y_padded.to(torch.int64) - - decoder_out = self.decoder(sos_y_padded) - - logits = self.joiner(encoder_out, decoder_out) - - # rnnt_loss requires 0 padded targets - # Note: y does not start with SOS - y_padded = y.pad(mode="constant", padding_value=0) - - assert hasattr(torchaudio.functional, "rnnt_loss"), ( - f"Current torchaudio version: {torchaudio.__version__}\n" - "Please install a version >= 0.10.0" - ) - - loss = torchaudio.functional.rnnt_loss( - logits=logits, - targets=y_padded, - logit_lengths=x_lens, - target_lengths=y_lens, - blank=blank_id, - reduction="sum", - ) - - return loss diff --git a/egs/librispeech/ASR/transducer_lstm/model.py b/egs/librispeech/ASR/transducer_lstm/model.py new file mode 120000 index 000000000..ebb6d774d --- /dev/null +++ b/egs/librispeech/ASR/transducer_lstm/model.py @@ -0,0 +1 @@ +../pruned_transducer_stateless2/model.py \ No newline at end of file diff --git a/egs/librispeech/ASR/transducer_lstm/optim.py b/egs/librispeech/ASR/transducer_lstm/optim.py new file mode 120000 index 000000000..e2deb4492 --- /dev/null +++ b/egs/librispeech/ASR/transducer_lstm/optim.py @@ -0,0 +1 @@ +../pruned_transducer_stateless2/optim.py \ No newline at end of file diff --git a/egs/librispeech/ASR/transducer_lstm/scaling.py b/egs/librispeech/ASR/transducer_lstm/scaling.py new file mode 120000 index 000000000..09d802cc4 --- /dev/null +++ b/egs/librispeech/ASR/transducer_lstm/scaling.py @@ -0,0 +1 @@ +../pruned_transducer_stateless2/scaling.py \ No newline at end of file diff --git a/egs/librispeech/ASR/transducer_lstm/test_encoder.py b/egs/librispeech/ASR/transducer_lstm/test_model.py similarity index 60% rename from egs/librispeech/ASR/transducer_lstm/test_encoder.py rename to egs/librispeech/ASR/transducer_lstm/test_model.py index cad5f1148..acd71455d 100755 --- a/egs/librispeech/ASR/transducer_lstm/test_encoder.py +++ b/egs/librispeech/ASR/transducer_lstm/test_model.py @@ -1,6 +1,5 @@ #!/usr/bin/env python3 -# -# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) +# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang) # # See ../../../../LICENSE for clarification regarding multiple authors # @@ -15,33 +14,30 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + + """ To run this file, do: cd icefall/egs/librispeech/ASR - python ./transducer_lstm/test_encoder.py + python ./pruned_transducer_stateless4/test_model.py """ -from encoder import LstmEncoder +from train import get_params, get_transducer_model -def test_encoder(): - encoder = LstmEncoder( - num_features=80, - hidden_size=1024, - proj_size=512, - output_dim=512, - subsampling_factor=4, - num_encoder_layers=12, - ) - num_params = sum(p.numel() for p in encoder.parameters() if p.requires_grad) - print(num_params) - # 93979284 - # 66427392 +def test_model(): + params = get_params() + params.vocab_size = 500 + params.blank_id = 0 + params.context_size = 2 + model = get_transducer_model(params) + num_param = sum([p.numel() for p in model.parameters()]) + print(f"Number of model parameters: {num_param}") def main(): - test_encoder() + test_model() if __name__ == "__main__": diff --git a/egs/librispeech/ASR/transducer_lstm/train.py b/egs/librispeech/ASR/transducer_lstm/train.py index 98859a58f..9cff0fa6f 100755 --- a/egs/librispeech/ASR/transducer_lstm/train.py +++ b/egs/librispeech/ASR/transducer_lstm/train.py @@ -16,20 +16,30 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - """ Usage: -export CUDA_VISIBLE_DEVICES="0,1,2" +export CUDA_VISIBLE_DEVICES="0,1,2,3" ./transducer_lstm/train.py \ - --world-size 3 \ + --world-size 4 \ --num-epochs 30 \ --start-epoch 0 \ --exp-dir transducer_lstm/exp \ --full-libri 1 \ - --max-duration 400 \ - --lr-factor 3 + --max-duration 300 + +# For mix precision training: + +./transducer_lstm/train.py \ + --world-size 4 \ + --num-epochs 30 \ + --start-epoch 0 \ + --use-fp16 1 \ + --exp-dir transducer_lstm/exp \ + --full-libri 1 \ + --max-duration 550 + """ @@ -38,32 +48,40 @@ import logging import warnings from pathlib import Path from shutil import copyfile -from typing import Optional, Tuple +from typing import Any, Dict, Optional, Tuple, Union import k2 +import optim import sentencepiece as spm import torch import torch.multiprocessing as mp import torch.nn as nn from asr_datamodule import LibriSpeechAsrDataModule -from decoder import Decoder from encoder import LstmEncoder +from decoder import Decoder from joiner import Joiner from lhotse.cut import Cut +from lhotse.dataset.sampling.base import CutSampler from lhotse.utils import fix_random_seed from model import Transducer -from noam import Noam +from optim import Eden, Eve from torch import Tensor +from torch.cuda.amp import GradScaler from torch.nn.parallel import DistributedDataParallel as DDP -from torch.nn.utils import clip_grad_norm_ from torch.utils.tensorboard import SummaryWriter -from icefall.checkpoint import load_checkpoint +from icefall import diagnostics +from icefall.checkpoint import load_checkpoint, remove_checkpoints from icefall.checkpoint import save_checkpoint as save_checkpoint_impl +from icefall.checkpoint import save_checkpoint_with_global_batch_idx from icefall.dist import cleanup_dist, setup_dist from icefall.env import get_env_info from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool +LRSchedulerType = Union[ + torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler +] + def get_parser(): parser = argparse.ArgumentParser( @@ -104,7 +122,16 @@ def get_parser(): default=0, help="""Resume training from from this epoch. If it is positive, it will load checkpoint from - transducer_lstm/exp/epoch-{start_epoch-1}.pt + transducer_stateless2/exp/epoch-{start_epoch-1}.pt + """, + ) + + parser.add_argument( + "--start-batch", + type=int, + default=0, + help="""If positive, --start-epoch is ignored and + it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt """, ) @@ -126,10 +153,68 @@ def get_parser(): ) parser.add_argument( - "--lr-factor", + "--initial-lr", type=float, - default=3.0, - help="The lr_factor for Noam optimizer", + default=0.003, + help="The initial learning rate. This value should not need to be changed.", + ) + + parser.add_argument( + "--lr-batches", + type=float, + default=5000, + help="""Number of steps that affects how rapidly the learning rate decreases. + We suggest not to change this.""", + ) + + parser.add_argument( + "--lr-epochs", + type=float, + default=6, + help="""Number of epochs that affects how rapidly the learning rate decreases. + """, + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; " + "2 means tri-gram", + ) + + parser.add_argument( + "--prune-range", + type=int, + default=5, + help="The prune range for rnnt loss, it means how many symbols(context)" + "we are using to compute the loss", + ) + + parser.add_argument( + "--lm-scale", + type=float, + default=0.25, + help="The scale to smooth the loss with lm " + "(output of prediction network) part.", + ) + + parser.add_argument( + "--am-scale", + type=float, + default=0.0, + help="The scale to smooth the loss with am (output of encoder network)" + "part.", + ) + + parser.add_argument( + "--simple-loss-scale", + type=float, + default=0.5, + help="To get pruning ranges, we will calculate a simple version" + "loss(joiner is just addition), this simple loss also uses for" + "training (as a regularization item). We will scale the simple loss" + "with this parameter before adding to the final loss.", ) parser.add_argument( @@ -140,11 +225,41 @@ def get_parser(): ) parser.add_argument( - "--context-size", + "--print-diagnostics", + type=str2bool, + default=False, + help="Accumulate stats on activations, print them and exit.", + ) + + parser.add_argument( + "--save-every-n", type=int, - default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + default=8000, + help="""Save checkpoint after processing this number of batches" + periodically. We save checkpoint to exp-dir/ whenever + params.batch_idx_train % save_every_n == 0. The checkpoint filename + has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt' + Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the + end of each epoch where `xxx` is the epoch number counting from 0. + """, + ) + + parser.add_argument( + "--keep-last-k", + type=int, + default=20, + help="""Only keep this number of checkpoints on disk. + For instance, if it is 3, there are only 3 checkpoints + in the exp-dir with filenames `checkpoint-xxx.pt`. + It does not affect checkpoints with name `epoch-xxx.pt`. + """, + ) + + parser.add_argument( + "--use-fp16", + type=str2bool, + default=False, + help="Whether to use half precision training.", ) return parser @@ -188,15 +303,10 @@ def get_params() -> AttributeDict: - subsampling_factor: The subsampling factor for the model. - - use_feat_batchnorm: Whether to do batch normalization for the - input features. - - - attention_dim: Hidden dim for multi-head attention model. + - encoder_dim: Hidden dim for multi-head attention model. - num_decoder_layers: Number of decoder layer of transformer decoder. - - weight_decay: The weight_decay for the optimizer. - - warm_step: The warm_step for Noam optimizer. """ params = AttributeDict( @@ -209,21 +319,20 @@ def get_params() -> AttributeDict: "log_interval": 50, "reset_interval": 200, "valid_interval": 3000, # For the 100h subset, use 800 - # parameters for conformer + # parameters for encoder "feature_dim": 80, - "encoder_out_dim": 512, "subsampling_factor": 4, + "encoder_dim": 512, "encoder_hidden_size": 1024, "num_encoder_layers": 4, "proj_size": 512, "vgg_frontend": False, - # decoder params - "decoder_embedding_dim": 1024, - "num_decoder_layers": 4, - "decoder_hidden_dim": 512, + # parameters for decoder + "decoder_dim": 512, + # parameters for joiner + "joiner_dim": 512, # parameters for Noam - "weight_decay": 1e-6, - "warm_step": 80000, # For the 100h subset, use 8k + "model_warm_step": 3000, # arg given to model, not for lrate "env_info": get_env_info(), } ) @@ -231,11 +340,11 @@ def get_params() -> AttributeDict: return params -def get_encoder_model(params: AttributeDict): +def get_encoder_model(params: AttributeDict) -> nn.Module: encoder = LstmEncoder( num_features=params.feature_dim, hidden_size=params.encoder_hidden_size, - output_dim=params.encoder_out_dim, + output_dim=params.encoder_dim, subsampling_factor=params.subsampling_factor, num_encoder_layers=params.num_encoder_layers, vgg_frontend=params.vgg_frontend, @@ -246,22 +355,24 @@ def get_encoder_model(params: AttributeDict): def get_decoder_model(params: AttributeDict) -> nn.Module: decoder = Decoder( vocab_size=params.vocab_size, - embedding_dim=params.encoder_out_dim, + decoder_dim=params.decoder_dim, blank_id=params.blank_id, context_size=params.context_size, ) return decoder -def get_joiner_model(params: AttributeDict): +def get_joiner_model(params: AttributeDict) -> nn.Module: joiner = Joiner( - input_dim=params.encoder_out_dim, - output_dim=params.vocab_size, + encoder_dim=params.encoder_dim, + decoder_dim=params.decoder_dim, + joiner_dim=params.joiner_dim, + vocab_size=params.vocab_size, ) return joiner -def get_transducer_model(params: AttributeDict): +def get_transducer_model(params: AttributeDict) -> nn.Module: encoder = get_encoder_model(params) decoder = get_decoder_model(params) joiner = get_joiner_model(params) @@ -270,6 +381,10 @@ def get_transducer_model(params: AttributeDict): encoder=encoder, decoder=decoder, joiner=joiner, + encoder_dim=params.encoder_dim, + decoder_dim=params.decoder_dim, + joiner_dim=params.joiner_dim, + vocab_size=params.vocab_size, ) return model @@ -278,15 +393,17 @@ def load_checkpoint_if_available( params: AttributeDict, model: nn.Module, optimizer: Optional[torch.optim.Optimizer] = None, - scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None, -) -> None: + scheduler: Optional[LRSchedulerType] = None, +) -> Optional[Dict[str, Any]]: """Load checkpoint from file. - If params.start_epoch is positive, it will load the checkpoint from - `params.start_epoch - 1`. Otherwise, this function does nothing. + If params.start_batch is positive, it will load the checkpoint from + `params.exp_dir/checkpoint-{params.start_batch}.pt`. Otherwise, if + params.start_epoch is positive, it will load the checkpoint from + `params.start_epoch - 1`. - Apart from loading state dict for `model`, `optimizer` and `scheduler`, - it also updates `best_train_epoch`, `best_train_loss`, `best_valid_epoch`, + Apart from loading state dict for `model` and `optimizer` it also updates + `best_train_epoch`, `best_train_loss`, `best_valid_epoch`, and `best_valid_loss` in `params`. Args: @@ -297,14 +414,19 @@ def load_checkpoint_if_available( optimizer: The optimizer that we are using. scheduler: - The learning rate scheduler we are using. + The scheduler that we are using. Returns: - Return None. + Return a dict containing previously saved training info. """ - if params.start_epoch <= 0: - return + if params.start_batch > 0: + filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt" + elif params.start_epoch > 0: + filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt" + else: + return None + + assert filename.is_file(), f"{filename} does not exist!" - filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt" saved_params = load_checkpoint( filename, model=model, @@ -322,6 +444,13 @@ def load_checkpoint_if_available( for k in keys: params[k] = saved_params[k] + if params.start_batch > 0: + if "cur_epoch" in saved_params: + params["start_epoch"] = saved_params["cur_epoch"] + + if "cur_batch_idx" in saved_params: + params["cur_batch_idx"] = saved_params["cur_batch_idx"] + return saved_params @@ -329,7 +458,9 @@ def save_checkpoint( params: AttributeDict, model: nn.Module, optimizer: Optional[torch.optim.Optimizer] = None, - scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None, + scheduler: Optional[LRSchedulerType] = None, + sampler: Optional[CutSampler] = None, + scaler: Optional[GradScaler] = None, rank: int = 0, ) -> None: """Save model, optimizer, scheduler and training stats to file. @@ -339,6 +470,12 @@ def save_checkpoint( It is returned by :func:`get_params`. model: The training model. + optimizer: + The optimizer used in the training. + sampler: + The sampler for the training dataset. + scaler: + The scaler used for mix precision training. """ if rank != 0: return @@ -349,6 +486,8 @@ def save_checkpoint( params=params, optimizer=optimizer, scheduler=scheduler, + sampler=sampler, + scaler=scaler, rank=rank, ) @@ -367,6 +506,7 @@ def compute_loss( sp: spm.SentencePieceProcessor, batch: dict, is_training: bool, + warmup: float = 1.0, ) -> Tuple[Tensor, MetricsTracker]: """ Compute CTC loss given the model and its inputs. @@ -383,6 +523,8 @@ def compute_loss( True for training. False for validation. When it is True, this function enables autograd during computation; when it is False, it disables autograd. + warmup: a floating point value which increases throughout training; + values >= 1.0 are fully warmed up and have all modules present. """ device = model.device feature = batch["inputs"] @@ -398,21 +540,42 @@ def compute_loss( y = k2.RaggedTensor(y).to(device) with torch.set_grad_enabled(is_training): - loss = model(x=feature, x_lens=feature_lens, y=y) + simple_loss, pruned_loss = model( + x=feature, + x_lens=feature_lens, + y=y, + prune_range=params.prune_range, + am_scale=params.am_scale, + lm_scale=params.lm_scale, + warmup=warmup, + ) + # after the main warmup step, we keep pruned_loss_scale small + # for the same amount of time (model_warm_step), to avoid + # overwhelming the simple_loss and causing it to diverge, + # in case it had not fully learned the alignment yet. + pruned_loss_scale = ( + 0.0 + if warmup < 1.0 + else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) + ) + loss = ( + params.simple_loss_scale * simple_loss + + pruned_loss_scale * pruned_loss + ) assert loss.requires_grad == is_training info = MetricsTracker() with warnings.catch_warnings(): warnings.simplefilter("ignore") - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - info["frames"] = ( - (feature_lens // params.subsampling_factor).sum().item() - ) + info["frames"] = ( + (feature_lens // params.subsampling_factor).sum().item() + ) # Note: We use reduction=sum while computing the loss. info["loss"] = loss.detach().cpu().item() + info["simple_loss"] = simple_loss.detach().cpu().item() + info["pruned_loss"] = pruned_loss.detach().cpu().item() return loss, info @@ -455,11 +618,14 @@ def train_one_epoch( params: AttributeDict, model: nn.Module, optimizer: torch.optim.Optimizer, + scheduler: LRSchedulerType, sp: spm.SentencePieceProcessor, train_dl: torch.utils.data.DataLoader, valid_dl: torch.utils.data.DataLoader, + scaler: GradScaler, tb_writer: Optional[SummaryWriter] = None, world_size: int = 1, + rank: int = 0, ) -> None: """Train the model for one epoch. @@ -474,51 +640,96 @@ def train_one_epoch( The model for training. optimizer: The optimizer we are using. + scheduler: + The learning rate scheduler, we call step() every step. train_dl: Dataloader for the training dataset. valid_dl: Dataloader for the validation dataset. + scaler: + The scaler used for mix precision training. tb_writer: Writer to write log messages to tensorboard. world_size: Number of nodes in DDP training. If it is 1, DDP is disabled. + rank: + The rank of the node in DDP training. If no DDP is used, it should + be set to 0. """ model.train() tot_loss = MetricsTracker() + cur_batch_idx = params.get("cur_batch_idx", 0) + for batch_idx, batch in enumerate(train_dl): + if batch_idx < cur_batch_idx: + continue + cur_batch_idx = batch_idx + params.batch_idx_train += 1 batch_size = len(batch["supervisions"]["text"]) - loss, loss_info = compute_loss( - params=params, - model=model, - sp=sp, - batch=batch, - is_training=True, - ) + with torch.cuda.amp.autocast(enabled=params.use_fp16): + loss, loss_info = compute_loss( + params=params, + model=model, + sp=sp, + batch=batch, + is_training=True, + warmup=(params.batch_idx_train / params.model_warm_step), + ) # summary stats tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info # NOTE: We use reduction==sum and loss is computed over utterances # in the batch and there is no normalization to it so far. - + scaler.scale(loss).backward() + scheduler.step_batch(params.batch_idx_train) + scaler.step(optimizer) + scaler.update() optimizer.zero_grad() - loss.backward() - clip_grad_norm_(model.parameters(), 5.0, 2.0) - optimizer.step() - if batch_idx % params.log_interval == 0: - logging.info( - f"Epoch {params.cur_epoch}, " - f"batch {batch_idx}, loss[{loss_info}], " - f"tot_loss[{tot_loss}], batch size: {batch_size}" + if params.print_diagnostics and batch_idx == 5: + return + + if ( + params.batch_idx_train > 0 + and params.batch_idx_train % params.save_every_n == 0 + ): + params.cur_batch_idx = batch_idx + save_checkpoint_with_global_batch_idx( + out_dir=params.exp_dir, + global_batch_idx=params.batch_idx_train, + model=model, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + scaler=scaler, + rank=rank, + ) + del params.cur_batch_idx + remove_checkpoints( + out_dir=params.exp_dir, + topk=params.keep_last_k, + rank=rank, ) if batch_idx % params.log_interval == 0: + cur_lr = scheduler.get_last_lr()[0] + logging.info( + f"Epoch {params.cur_epoch}, " + f"batch {batch_idx}, loss[{loss_info}], " + f"tot_loss[{tot_loss}], batch size: {batch_size}, " + f"lr: {cur_lr:.2e}" + ) if tb_writer is not None: + tb_writer.add_scalar( + "train/learning_rate", cur_lr, params.batch_idx_train + ) + loss_info.write_summary( tb_writer, "train/current_", params.batch_idx_train ) @@ -564,8 +775,7 @@ def run(rank, world_size, args): params = get_params() params.update(vars(args)) if params.full_libri is False: - params.valid_interval = 800 - params.warm_step = 8000 + params.valid_interval = 1600 fix_random_seed(params.seed) if world_size > 1: @@ -596,29 +806,39 @@ def run(rank, world_size, args): logging.info("About to create model") model = get_transducer_model(params) - checkpoints = load_checkpoint_if_available(params=params, model=model) - - num_param = sum([p.numel() for p in model.parameters() if p.requires_grad]) + num_param = sum([p.numel() for p in model.parameters()]) logging.info(f"Number of model parameters: {num_param}") + checkpoints = load_checkpoint_if_available(params=params, model=model) + model.to(device) if world_size > 1: logging.info("Using DDP") model = DDP(model, device_ids=[rank]) model.device = device - optimizer = Noam( - model.parameters(), - model_size=params.encoder_hidden_size, - factor=params.lr_factor, - warm_step=params.warm_step, - weight_decay=params.weight_decay, - ) + optimizer = Eve(model.parameters(), lr=params.initial_lr) + + scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs) if checkpoints and "optimizer" in checkpoints: logging.info("Loading optimizer state dict") optimizer.load_state_dict(checkpoints["optimizer"]) + if ( + checkpoints + and "scheduler" in checkpoints + and checkpoints["scheduler"] is not None + ): + logging.info("Loading scheduler state dict") + scheduler.load_state_dict(checkpoints["scheduler"]) + + if params.print_diagnostics: + opts = diagnostics.TensorDiagnosticOptions( + 2 ** 22 + ) # allow 4 megabytes per sub-module + diagnostic = diagnostics.attach_diagnostics(model, opts) + librispeech = LibriSpeechAsrDataModule(args) train_cuts = librispeech.train_clean_100_cuts() @@ -628,75 +848,81 @@ def run(rank, world_size, args): def remove_short_and_long_utt(c: Cut): # Keep only utterances with duration between 1 second and 20 seconds + # + # Caution: There is a reason to select 20.0 here. Please see + # ../local/display_manifest_statistics.py + # + # You should use ../local/display_manifest_statistics.py to get + # an utterance duration distribution for your dataset to select + # the threshold return 1.0 <= c.duration <= 20.0 - num_in_total = len(train_cuts) - train_cuts = train_cuts.filter(remove_short_and_long_utt) - try: - num_left = len(train_cuts) - num_removed = num_in_total - num_left - removed_percent = num_removed / num_in_total * 100 + if params.start_batch > 0 and checkpoints and "sampler" in checkpoints: + # We only load the sampler's state dict when it loads a checkpoint + # saved in the middle of an epoch + sampler_state_dict = checkpoints["sampler"] + else: + sampler_state_dict = None - logging.info( - f"Before removing short and long utterances: {num_in_total}" - ) - logging.info(f"After removing short and long utterances: {num_left}") - logging.info( - f"Removed {num_removed} utterances ({removed_percent:.5f}%)" - ) - except TypeError as e: - # You can ignore this error as previous versions of Lhotse work fine - # for the above code. In recent versions of Lhotse, it uses - # lazy filter, producing cutsets that don't have the __len__ method - logging.info(str(e)) - - train_dl = librispeech.train_dataloaders(train_cuts) + train_dl = librispeech.train_dataloaders( + train_cuts, sampler_state_dict=sampler_state_dict + ) valid_cuts = librispeech.dev_clean_cuts() valid_cuts += librispeech.dev_other_cuts() valid_dl = librispeech.valid_dataloaders(valid_cuts) - scan_pessimistic_batches_for_oom( - model=model, - train_dl=train_dl, - optimizer=optimizer, - sp=sp, - params=params, - ) + if not params.print_diagnostics: + scan_pessimistic_batches_for_oom( + model=model, + train_dl=train_dl, + optimizer=optimizer, + sp=sp, + params=params, + ) + + scaler = GradScaler(enabled=params.use_fp16) + if checkpoints and "grad_scaler" in checkpoints: + logging.info("Loading grad scaler state dict") + scaler.load_state_dict(checkpoints["grad_scaler"]) for epoch in range(params.start_epoch, params.num_epochs): + scheduler.step_epoch(epoch) fix_random_seed(params.seed + epoch) train_dl.sampler.set_epoch(epoch) - cur_lr = optimizer._rate if tb_writer is not None: - tb_writer.add_scalar( - "train/learning_rate", cur_lr, params.batch_idx_train - ) tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) - if rank == 0: - logging.info("epoch {}, learning rate {}".format(epoch, cur_lr)) - params.cur_epoch = epoch train_one_epoch( params=params, model=model, optimizer=optimizer, + scheduler=scheduler, sp=sp, train_dl=train_dl, valid_dl=valid_dl, + scaler=scaler, tb_writer=tb_writer, world_size=world_size, + rank=rank, ) + if params.print_diagnostics: + diagnostic.print_diagnostics() + break + save_checkpoint( params=params, model=model, optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + scaler=scaler, rank=rank, ) @@ -723,17 +949,21 @@ def scan_pessimistic_batches_for_oom( for criterion, cuts in batches.items(): batch = train_dl.dataset[cuts] try: - optimizer.zero_grad() - loss, _ = compute_loss( - params=params, - model=model, - sp=sp, - batch=batch, - is_training=True, - ) + # warmup = 0.0 is so that the derivs for the pruned loss stay zero + # (i.e. are not remembered by the decaying-average in adam), because + # we want to avoid these params being subject to shrinkage in adam. + with torch.cuda.amp.autocast(enabled=params.use_fp16): + loss, _ = compute_loss( + params=params, + model=model, + sp=sp, + batch=batch, + is_training=True, + warmup=0.0, + ) loss.backward() - clip_grad_norm_(model.parameters(), 5.0, 2.0) optimizer.step() + optimizer.zero_grad() except RuntimeError as e: if "CUDA out of memory" in str(e): logging.error(