From 7c5dba62e006a8e78946f93717eea02b4cb96d16 Mon Sep 17 00:00:00 2001
From: Desh Raj
Date: Sat, 11 Mar 2023 12:07:23 -0500
Subject: [PATCH] update results

---
 egs/librispeech/ASR/RESULTS.md                |  21 +-
 .../ASR/zipformer_ctc/asr_datamodule.py       |   2 +-
 egs/librispeech/ASR/zipformer_ctc/decode.py   |  11 +-
 egs/librispeech/ASR/zipformer_ctc/decoder.py  |  83 ++++
 egs/librispeech/ASR/zipformer_ctc/model.py    |   1 +
 .../ASR/zipformer_ctc/pretrained.py           | 430 ------------------
 6 files changed, 105 insertions(+), 443 deletions(-)
 delete mode 100755 egs/librispeech/ASR/zipformer_ctc/pretrained.py

diff --git a/egs/librispeech/ASR/RESULTS.md b/egs/librispeech/ASR/RESULTS.md
index becb8d408..c958951a5 100644
--- a/egs/librispeech/ASR/RESULTS.md
+++ b/egs/librispeech/ASR/RESULTS.md
@@ -4,20 +4,19 @@

 #### [zipformer_ctc](./zipformer_ctc)

-See <> for more details.
+See for more details.

 You can find a pretrained model, training logs, decoding logs, and decoding
 results at:
-<>
+

 Number of model parameters: 86083707, i.e., 86.08 M

 | decoding method         | test-clean | test-other | comment             |
 |-------------------------|------------|------------|---------------------|
-| ctc-decoding            | 2.50       | 5.86       | --epoch 30 --avg 10 |
-| whole-lattice-rescoring |            |            | --epoch 30 --avg 10 |
-| attention-rescoring     |            |            | --epoch 30 --avg 10 |
-| rnn-lm                  |            |            | --epoch 30 --avg 10 |
+| ctc-decoding            | 2.50       | 5.86       | --epoch 30 --avg 9  |
+| whole-lattice-rescoring | 2.44       | 5.38       | --epoch 30 --avg 9  |
+| attention-rescoring     | 2.35       | 5.16       | --epoch 30 --avg 9  |

 The training command is:

@@ -36,6 +35,16 @@ The training command is:

 The tensorboard log can be found at

+The decoding command is:
+
+```bash
+./zipformer_ctc/decode.py \
+  --epoch 99 --avg 1 --use-averaged-model False \
+  --exp-dir zipformer_ctc/exp \
+  --lang-dir data/lang_bpe_500 \
+  --lm-dir data/lm \
+  --method ctc-decoding
+```
 ### Streaming Zipformer-Transducer (Pruned Stateless Transducer + Streaming Zipformer)

 #### [pruned_transducer_stateless7_streaming](./pruned_transducer_stateless7_streaming)

diff --git a/egs/librispeech/ASR/zipformer_ctc/asr_datamodule.py b/egs/librispeech/ASR/zipformer_ctc/asr_datamodule.py
index d5190856d..fa1b8cca3 120000
--- a/egs/librispeech/ASR/zipformer_ctc/asr_datamodule.py
+++ b/egs/librispeech/ASR/zipformer_ctc/asr_datamodule.py
@@ -1 +1 @@
-/exp/draj/mini_scale_2022/icefall/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py
\ No newline at end of file
+../tdnn_lstm_ctc/asr_datamodule.py
\ No newline at end of file
diff --git a/egs/librispeech/ASR/zipformer_ctc/decode.py b/egs/librispeech/ASR/zipformer_ctc/decode.py
index 6e614b446..7f605e2c8 100755
--- a/egs/librispeech/ASR/zipformer_ctc/decode.py
+++ b/egs/librispeech/ASR/zipformer_ctc/decode.py
@@ -313,7 +313,7 @@ def decode_one_batch(
     feature_lens = supervisions["num_frames"].to(device)

     nnet_output, _ = model.encoder(feature, feature_lens)
-    nnet_output = model.ctc_output(nnet_output)
+    ctc_output = model.ctc_output(nnet_output)
     # nnet_output is (N, T, C)

     supervision_segments = torch.stack(
@@ -334,7 +334,7 @@ def decode_one_batch(
         decoding_graph = H

     lattice = get_lattice(
-        nnet_output=nnet_output,
+        nnet_output=ctc_output,
         decoding_graph=decoding_graph,
         supervision_segments=supervision_segments,
         search_beam=params.search_beam,
@@ -407,12 +407,11 @@ def decode_one_batch(
     ]

     lm_scale_list = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7]
-    lm_scale_list += [0.8, 0.9, 1.0, 1.1, 1.2, 1.3]
-    lm_scale_list += [1.4, 1.5, 1.6, 1.7, 1.8, 1.9, 2.0]

     nnet_output = nnet_output.permute(1, 0, 2)  # (N, T, C) -> (T, N, C)
    mask = encoder_padding_mask(nnet_output.size(0), supervisions)
    mask = mask.to(nnet_output.device) if mask is not None else None
+    mmodel = model.decoder.module if hasattr(model.decoder, "module") else model.decoder

     if params.method == "nbest-rescoring":
         best_path_dict = rescore_with_n_best_list(
@@ -439,7 +438,7 @@ def decode_one_batch(
         best_path_dict = rescore_with_attention_decoder(
             lattice=rescored_lattice,
             num_paths=params.num_paths,
-            model=model,
+            model=mmodel,
             memory=nnet_output,
             memory_key_padding_mask=mask,
             sos_id=sos_id,
@@ -458,7 +457,7 @@ def decode_one_batch(
             lattice=rescored_lattice,
             num_paths=params.num_paths,
             rnn_lm_model=rnn_lm_model,
-            model=model,
+            model=mmodel,
             memory=nnet_output,
             memory_key_padding_mask=mask,
             sos_id=sos_id,
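The `mmodel` line added above unwraps a decoder that may be wrapped in `torch.nn.parallel.DistributedDataParallel`, which exposes the underlying network via its `.module` attribute; this is why the rescoring hunks pass `model=mmodel` rather than `model=model`. A minimal sketch of the pattern (the `unwrap` helper is illustrative, not part of this patch):

```python
import torch.nn as nn

def unwrap(module: nn.Module) -> nn.Module:
    """Return the network wrapped by DistributedDataParallel, if any."""
    # DDP (and DataParallel) store the original module in `.module`;
    # plain modules pass through unchanged.
    return module.module if hasattr(module, "module") else module

# Equivalent to the inline expression in decode.py:
# mmodel = model.decoder.module if hasattr(model.decoder, "module") else model.decoder
```
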
diff --git a/egs/librispeech/ASR/zipformer_ctc/decoder.py b/egs/librispeech/ASR/zipformer_ctc/decoder.py
index 784d3837b..8dec048a1 100644
--- a/egs/librispeech/ASR/zipformer_ctc/decoder.py
+++ b/egs/librispeech/ASR/zipformer_ctc/decoder.py
@@ -145,6 +145,84 @@ class Decoder(nn.Module):

         return decoder_loss

+    @torch.jit.export
+    def decoder_nll(
+        self,
+        memory: torch.Tensor,
+        memory_key_padding_mask: torch.Tensor,
+        token_ids: List[torch.Tensor],
+        sos_id: int,
+        eos_id: int,
+    ) -> torch.Tensor:
+        """
+        Args:
+          memory:
+            The output of the encoder, with shape (T, N, C).
+          memory_key_padding_mask:
+            The padding mask from the encoder.
+          token_ids:
+            A list of lists of token IDs (e.g., word piece IDs).
+            Each sublist represents an utterance.
+          sos_id:
+            The token ID for SOS.
+          eos_id:
+            The token ID for EOS.
+        Returns:
+          A 2-D tensor of shape (len(token_ids), max_token_length)
+          representing the cross entropy loss (i.e., negative log-likelihood).
+        """
+        # The common part between this function and decoder_forward could be
+        # extracted as a separate function.
+        if isinstance(token_ids[0], torch.Tensor):
+            # This branch is executed by torchscript in C++.
+            # See https://github.com/k2-fsa/k2/pull/870
+            # https://github.com/k2-fsa/k2/blob/3c1c18400060415b141ccea0115fd4bf0ad6234e/k2/torch/bin/attention_rescore.cu#L286
+            token_ids = [tolist(t) for t in token_ids]
+
+        ys_in = add_sos(token_ids, sos_id=sos_id)
+        ys_in = [torch.tensor(y) for y in ys_in]
+        ys_in_pad = pad_sequence(ys_in, batch_first=True, padding_value=float(eos_id))
+
+        ys_out = add_eos(token_ids, eos_id=eos_id)
+        ys_out = [torch.tensor(y) for y in ys_out]
+        ys_out_pad = pad_sequence(ys_out, batch_first=True, padding_value=float(-1))
+
+        device = memory.device
+        ys_in_pad = ys_in_pad.to(device, dtype=torch.int64)
+        ys_out_pad = ys_out_pad.to(device, dtype=torch.int64)
+
+        tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to(device)
+
+        tgt_key_padding_mask = decoder_padding_mask(ys_in_pad, ignore_id=eos_id)
+        # TODO: Use length information to create the decoder padding mask
+        # We set the first column to False since the first column in ys_in_pad
+        # contains sos_id, which is the same as eos_id in our current setting.
+        tgt_key_padding_mask[:, 0] = False
+
+        tgt = self.decoder_embed(ys_in_pad)  # (B, T) -> (B, T, F)
+        tgt = self.decoder_pos(tgt)
+        tgt = tgt.permute(1, 0, 2)  # (B, T, F) -> (T, B, F)
+        pred_pad = self.decoder(
+            tgt=tgt,
+            memory=memory,
+            tgt_mask=tgt_mask,
+            tgt_key_padding_mask=tgt_key_padding_mask,
+            memory_key_padding_mask=memory_key_padding_mask,
+        )  # (T, B, F)
+        pred_pad = pred_pad.permute(1, 0, 2)  # (T, B, F) -> (B, T, F)
+        pred_pad = self.decoder_output_layer(pred_pad)  # (B, T, F)
+
+        # nll: negative log-likelihood
+        nll = torch.nn.functional.cross_entropy(
+            pred_pad.view(-1, self.decoder_num_class),
+            ys_out_pad.view(-1),
+            ignore_index=-1,
+            reduction="none",
+        )
+
+        nll = nll.view(pred_pad.shape[0], -1)
+
+        return nll

 def add_sos(token_ids: List[List[int]], sos_id: int) -> List[List[int]]:
     """Prepend sos_id to each utterance.
@@ -213,3 +291,8 @@ def generate_square_subsequent_mask(sz: int) -> torch.Tensor:
         .masked_fill(mask == 1, float(0.0))
     )
     return mask
+
+
+def tolist(t: torch.Tensor) -> List[int]:
+    """Used by jit"""
+    return torch.jit.annotate(List[int], t.tolist())
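For intuition, `decoder_nll` returns one cross-entropy value per output position, and padded positions (target `-1`) contribute exactly zero thanks to `ignore_index=-1`. A self-contained sketch of that masking-and-reduction behavior; the shapes and token IDs below are made up for illustration:

```python
import torch
import torch.nn.functional as F

vocab_size = 500
logits = torch.randn(2, 4, vocab_size)               # stand-in for pred_pad: (B, T, vocab)
targets = torch.full((2, 4), -1, dtype=torch.long)   # -1 marks padding, as in ys_out_pad
targets[0] = torch.tensor([7, 8, 9, 1])              # three tokens followed by eos_id=1
targets[1, :3] = torch.tensor([3, 4, 1])             # shorter utterance; last slot stays padded

nll = F.cross_entropy(
    logits.view(-1, vocab_size),
    targets.view(-1),
    ignore_index=-1,
    reduction="none",
).view(2, -1)                                        # back to (B, T), like decoder_nll

# Ignored positions are exactly 0, so the row sum is the total negative
# log-likelihood of each utterance; rescoring typically negates it as a score.
utt_scores = -nll.sum(dim=1)
```
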
diff --git a/egs/librispeech/ASR/zipformer_ctc/model.py b/egs/librispeech/ASR/zipformer_ctc/model.py
index 560845339..2aeb8a072 100644
--- a/egs/librispeech/ASR/zipformer_ctc/model.py
+++ b/egs/librispeech/ASR/zipformer_ctc/model.py
@@ -14,6 +14,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+from typing import List
 import k2
 import torch

diff --git a/egs/librispeech/ASR/zipformer_ctc/pretrained.py b/egs/librispeech/ASR/zipformer_ctc/pretrained.py
deleted file mode 100755
index 30def9c40..000000000
--- a/egs/librispeech/ASR/zipformer_ctc/pretrained.py
+++ /dev/null
@@ -1,430 +0,0 @@
-#!/usr/bin/env python3
-# Copyright      2021  Xiaomi Corp.        (authors: Fangjun Kuang,
-#                                                    Mingshuang Luo)
-#
-# See ../../../../LICENSE for clarification regarding multiple authors
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#      http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-
-import argparse
-import logging
-import math
-from typing import List
-
-import k2
-import kaldifeat
-import sentencepiece as spm
-import torch
-import torchaudio
-from conformer import Conformer
-from torch.nn.utils.rnn import pad_sequence
-
-from icefall.decode import (
-    get_lattice,
-    one_best_decoding,
-    rescore_with_attention_decoder,
-    rescore_with_whole_lattice,
-)
-from icefall.utils import AttributeDict, get_texts
-
-
-def get_parser():
-    parser = argparse.ArgumentParser(
-        formatter_class=argparse.ArgumentDefaultsHelpFormatter
-    )
-
-    parser.add_argument(
-        "--checkpoint",
-        type=str,
-        required=True,
-        help="Path to the checkpoint. "
-        "The checkpoint is assumed to be saved by "
-        "icefall.checkpoint.save_checkpoint().",
-    )
-
-    parser.add_argument(
-        "--words-file",
-        type=str,
-        help="""Path to words.txt.
-        Used only when method is not ctc-decoding.
-        """,
-    )
-
-    parser.add_argument(
-        "--HLG",
-        type=str,
-        help="""Path to HLG.pt.
-        Used only when method is not ctc-decoding.
- """, - ) - - parser.add_argument( - "--bpe-model", - type=str, - help="""Path to bpe.model. - Used only when method is ctc-decoding. - """, - ) - - parser.add_argument( - "--method", - type=str, - default="1best", - help="""Decoding method. - Possible values are: - (0) ctc-decoding - Use CTC decoding. It uses a sentence - piece model, i.e., lang_dir/bpe.model, to convert - word pieces to words. It needs neither a lexicon - nor an n-gram LM. - (1) 1best - Use the best path as decoding output. Only - the transformer encoder output is used for decoding. - We call it HLG decoding. - (2) whole-lattice-rescoring - Use an LM to rescore the - decoding lattice and then use 1best to decode the - rescored lattice. - We call it HLG decoding + n-gram LM rescoring. - (3) attention-decoder - Extract n paths from the rescored - lattice and use the transformer attention decoder for - rescoring. - We call it HLG decoding + n-gram LM rescoring + attention - decoder rescoring. - """, - ) - - parser.add_argument( - "--G", - type=str, - help="""An LM for rescoring. - Used only when method is - whole-lattice-rescoring or attention-decoder. - It's usually a 4-gram LM. - """, - ) - - parser.add_argument( - "--num-paths", - type=int, - default=100, - help=""" - Used only when method is attention-decoder. - It specifies the size of n-best list.""", - ) - - parser.add_argument( - "--ngram-lm-scale", - type=float, - default=1.3, - help=""" - Used only when method is whole-lattice-rescoring and attention-decoder. - It specifies the scale for n-gram LM scores. - (Note: You need to tune it on a dataset.) - """, - ) - - parser.add_argument( - "--attention-decoder-scale", - type=float, - default=1.2, - help=""" - Used only when method is attention-decoder. - It specifies the scale for attention decoder scores. - (Note: You need to tune it on a dataset.) - """, - ) - - parser.add_argument( - "--nbest-scale", - type=float, - default=0.5, - help=""" - Used only when method is attention-decoder. - It specifies the scale for lattice.scores when - extracting n-best lists. A smaller value results in - more unique number of paths with the risk of missing - the best path. - """, - ) - - parser.add_argument( - "--sos-id", - type=int, - default=1, - help=""" - Used only when method is attention-decoder. - It specifies ID for the SOS token. - """, - ) - - parser.add_argument( - "--num-classes", - type=int, - default=500, - help=""" - Vocab size in the BPE model. - """, - ) - - parser.add_argument( - "--eos-id", - type=int, - default=1, - help=""" - Used only when method is attention-decoder. - It specifies ID for the EOS token. - """, - ) - - parser.add_argument( - "sound_files", - type=str, - nargs="+", - help="The input sound file(s) to transcribe. " - "Supported formats are those supported by torchaudio.load(). " - "For example, wav and flac are supported. 
" - "The sample rate has to be 16kHz.", - ) - - return parser - - -def get_params() -> AttributeDict: - params = AttributeDict( - { - "sample_rate": 16000, - # parameters for conformer - "subsampling_factor": 4, - "vgg_frontend": False, - "use_feat_batchnorm": True, - "feature_dim": 80, - "nhead": 8, - "attention_dim": 512, - "num_decoder_layers": 6, - # parameters for decoding - "search_beam": 20, - "output_beam": 8, - "min_active_states": 30, - "max_active_states": 10000, - "use_double_scores": True, - } - ) - return params - - -def read_sound_files( - filenames: List[str], expected_sample_rate: float -) -> List[torch.Tensor]: - """Read a list of sound files into a list 1-D float32 torch tensors. - Args: - filenames: - A list of sound filenames. - expected_sample_rate: - The expected sample rate of the sound files. - Returns: - Return a list of 1-D float32 torch tensors. - """ - ans = [] - for f in filenames: - wave, sample_rate = torchaudio.load(f) - assert ( - sample_rate == expected_sample_rate - ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}" - # We use only the first channel - ans.append(wave[0]) - return ans - - -def main(): - parser = get_parser() - args = parser.parse_args() - - params = get_params() - if args.method != "attention-decoder": - # to save memory as the attention decoder - # will not be used - params.num_decoder_layers = 0 - - params.update(vars(args)) - logging.info(f"{params}") - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", 0) - - logging.info(f"device: {device}") - - logging.info("Creating model") - model = Conformer( - num_features=params.feature_dim, - nhead=params.nhead, - d_model=params.attention_dim, - num_classes=params.num_classes, - subsampling_factor=params.subsampling_factor, - num_decoder_layers=params.num_decoder_layers, - vgg_frontend=params.vgg_frontend, - use_feat_batchnorm=params.use_feat_batchnorm, - ) - - checkpoint = torch.load(args.checkpoint, map_location="cpu") - model.load_state_dict(checkpoint["model"], strict=False) - model.to(device) - model.eval() - - logging.info("Constructing Fbank computer") - opts = kaldifeat.FbankOptions() - opts.device = device - opts.frame_opts.dither = 0 - opts.frame_opts.snip_edges = False - opts.frame_opts.samp_freq = params.sample_rate - opts.mel_opts.num_bins = params.feature_dim - - fbank = kaldifeat.Fbank(opts) - - logging.info(f"Reading sound files: {params.sound_files}") - waves = read_sound_files( - filenames=params.sound_files, expected_sample_rate=params.sample_rate - ) - waves = [w.to(device) for w in waves] - - logging.info("Decoding started") - features = fbank(waves) - - features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10)) - - # Note: We don't use key padding mask for attention during decoding - with torch.no_grad(): - nnet_output, memory, memory_key_padding_mask = model(features) - - batch_size = nnet_output.shape[0] - supervision_segments = torch.tensor( - [[i, 0, nnet_output.shape[1]] for i in range(batch_size)], - dtype=torch.int32, - ) - - if params.method == "ctc-decoding": - logging.info("Use CTC decoding") - bpe_model = spm.SentencePieceProcessor() - bpe_model.load(params.bpe_model) - max_token_id = params.num_classes - 1 - - H = k2.ctc_topo( - max_token=max_token_id, - modified=params.num_classes > 500, - device=device, - ) - - lattice = get_lattice( - nnet_output=nnet_output, - decoding_graph=H, - supervision_segments=supervision_segments, - search_beam=params.search_beam, - 
-            output_beam=params.output_beam,
-            min_active_states=params.min_active_states,
-            max_active_states=params.max_active_states,
-            subsampling_factor=params.subsampling_factor,
-        )
-
-        best_path = one_best_decoding(
-            lattice=lattice, use_double_scores=params.use_double_scores
-        )
-        token_ids = get_texts(best_path)
-        hyps = bpe_model.decode(token_ids)
-        hyps = [s.split() for s in hyps]
-    elif params.method in [
-        "1best",
-        "whole-lattice-rescoring",
-        "attention-decoder",
-    ]:
-        logging.info(f"Loading HLG from {params.HLG}")
-        HLG = k2.Fsa.from_dict(torch.load(params.HLG, map_location="cpu"))
-        HLG = HLG.to(device)
-        if not hasattr(HLG, "lm_scores"):
-            # For whole-lattice-rescoring and attention-decoder
-            HLG.lm_scores = HLG.scores.clone()
-
-        if params.method in [
-            "whole-lattice-rescoring",
-            "attention-decoder",
-        ]:
-            logging.info(f"Loading G from {params.G}")
-            G = k2.Fsa.from_dict(torch.load(params.G, map_location="cpu"))
-            # Add epsilon self-loops to G as we will compose
-            # it with the whole lattice later
-            G = G.to(device)
-            G = k2.add_epsilon_self_loops(G)
-            G = k2.arc_sort(G)
-            G.lm_scores = G.scores.clone()
-
-        lattice = get_lattice(
-            nnet_output=nnet_output,
-            decoding_graph=HLG,
-            supervision_segments=supervision_segments,
-            search_beam=params.search_beam,
-            output_beam=params.output_beam,
-            min_active_states=params.min_active_states,
-            max_active_states=params.max_active_states,
-            subsampling_factor=params.subsampling_factor,
-        )
-
-        if params.method == "1best":
-            logging.info("Use HLG decoding")
-            best_path = one_best_decoding(
-                lattice=lattice, use_double_scores=params.use_double_scores
-            )
-        elif params.method == "whole-lattice-rescoring":
-            logging.info("Use HLG decoding + LM rescoring")
-            best_path_dict = rescore_with_whole_lattice(
-                lattice=lattice,
-                G_with_epsilon_loops=G,
-                lm_scale_list=[params.ngram_lm_scale],
-            )
-            best_path = next(iter(best_path_dict.values()))
-        elif params.method == "attention-decoder":
-            logging.info("Use HLG + LM rescoring + attention decoder rescoring")
-            rescored_lattice = rescore_with_whole_lattice(
-                lattice=lattice, G_with_epsilon_loops=G, lm_scale_list=None
-            )
-            best_path_dict = rescore_with_attention_decoder(
-                lattice=rescored_lattice,
-                num_paths=params.num_paths,
-                model=model,
-                memory=memory,
-                memory_key_padding_mask=memory_key_padding_mask,
-                sos_id=params.sos_id,
-                eos_id=params.eos_id,
-                nbest_scale=params.nbest_scale,
-                ngram_lm_scale=params.ngram_lm_scale,
-                attention_scale=params.attention_decoder_scale,
-            )
-            best_path = next(iter(best_path_dict.values()))
-
-        hyps = get_texts(best_path)
-        word_sym_table = k2.SymbolTable.from_file(params.words_file)
-        hyps = [[word_sym_table[i] for i in ids] for ids in hyps]
-    else:
-        raise ValueError(f"Unsupported decoding method: {params.method}")
-
-    s = "\n"
-    for filename, hyp in zip(params.sound_files, hyps):
-        words = " ".join(hyp)
-        s += f"{filename}:\n{words}\n\n"
-    logging.info(s)
-
-    logging.info("Decoding Done")
-
-
-if __name__ == "__main__":
-    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-
-    logging.basicConfig(format=formatter, level=logging.INFO)
-    main()
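One last illustration: `generate_square_subsequent_mask`, which `decoder_nll` uses to build `tgt_mask`, produces the standard additive causal mask. The snippet below condenses the helper retained in decoder.py and can be run standalone:

```python
import torch

def generate_square_subsequent_mask(sz: int) -> torch.Tensor:
    # Allowed (current and past) positions get 0.0; future positions get -inf,
    # which the attention softmax turns into a zero weight.
    mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
    return mask.float().masked_fill(mask == 0, float("-inf")).masked_fill(mask == 1, 0.0)

print(generate_square_subsequent_mask(3))
# tensor([[0., -inf, -inf],
#         [0., 0., -inf],
#         [0., 0., 0.]])
```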