diff --git a/egs/tedlium3/ASR/RESULTS.md b/egs/tedlium3/ASR/RESULTS.md
index e51bbb298..e89881c33 100644
--- a/egs/tedlium3/ASR/RESULTS.md
+++ b/egs/tedlium3/ASR/RESULTS.md
@@ -4,7 +4,7 @@
 
 #### 2022-03-21
 
-Using the codes from this PR.
+Using the code from this PR: https://github.com/k2-fsa/icefall/pull/261.
 
 The WERs are
 
@@ -62,6 +62,18 @@ avg=13
   --max-duration 100 \
   --decoding-method modified_beam_search \
   --beam-size 4
+
+## fast beam search
+./pruned_transducer_stateless/decode.py \
+  --epoch $epoch \
+  --avg $avg \
+  --exp-dir ./pruned_transducer_stateless/exp \
+  --bpe-model ./data/lang_bpe_500/bpe.model \
+  --max-duration 1500 \
+  --decoding-method fast_beam_search \
+  --beam 4 \
+  --max-contexts 4 \
+  --max-states 8
 ```
 
 A pre-trained model and decoding logs can be found at
@@ -85,6 +97,7 @@ The WERs are
 | greedy search | 7.19 | 6.70 | --epoch 29, --avg 11, --max-duration 100 |
 | beam search (beam size 4) | 7.02 | 6.36 | --epoch 29, --avg 11, --max-duration 100 |
 | modified beam search (beam size 4) | 6.91 | 6.33 | --epoch 29, --avg 11, --max-duration 100 |
+| fast beam search (set as default) | 7.14 | 6.50 | --epoch 29, --avg 11, --max-duration 1500 |
 
 The training command for reproducing is given below:
 
diff --git a/egs/tedlium3/ASR/pruned_transducer_stateless/beam_search.py b/egs/tedlium3/ASR/pruned_transducer_stateless/beam_search.py
index c01fa966c..061d09e2f 100644
--- a/egs/tedlium3/ASR/pruned_transducer_stateless/beam_search.py
+++ b/egs/tedlium3/ASR/pruned_transducer_stateless/beam_search.py
@@ -1,5 +1,5 @@
-# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang
-#                                        Mingshuang Luo)
+# Copyright 2020 Xiaomi Corp. (authors: Fangjun Kuang
+#                                        Mingshuang Luo)
 #
 # See ../../../../LICENSE for clarification regarding multiple authors
 #
@@ -18,14 +18,100 @@
 from dataclasses import dataclass
 from typing import Dict, List, Optional
 
+import k2
 import torch
 from model import Transducer
 
+from icefall.decode import one_best_decoding
+from icefall.utils import get_texts
+
+
+def fast_beam_search(
+    model: Transducer,
+    decoding_graph: k2.Fsa,
+    encoder_out: torch.Tensor,
+    encoder_out_lens: torch.Tensor,
+    beam: float,
+    max_states: int,
+    max_contexts: int,
+) -> List[List[int]]:
+    """It limits the maximum number of symbols per frame to 1.
+
+    Args:
+      model:
+        An instance of `Transducer`.
+      decoding_graph:
+        Decoding graph used for decoding; may be a TrivialGraph or an HLG.
+      encoder_out:
+        A tensor of shape (N, T, C) from the encoder.
+      encoder_out_lens:
+        A tensor of shape (N,) containing the number of frames in `encoder_out`
+        before padding.
+      beam:
+        Beam value, similar to the beam used in Kaldi.
+      max_states:
+        Max states per stream per frame.
+      max_contexts:
+        Max contexts per stream per frame.
+    Returns:
+      Return the decoded result.
+    """
+    assert encoder_out.ndim == 3
+
+    context_size = model.decoder.context_size
+    vocab_size = model.decoder.vocab_size
+    unk_id = model.decoder.unk_id
+
+    B, T, C = encoder_out.shape
+
+    config = k2.RnntDecodingConfig(
+        vocab_size=vocab_size,
+        decoder_history_len=context_size,
+        beam=beam,
+        max_contexts=max_contexts,
+        max_states=max_states,
+    )
+    individual_streams = []
+    for i in range(B):
+        individual_streams.append(k2.RnntDecodingStream(decoding_graph))
+    decoding_streams = k2.RnntDecodingStreams(individual_streams, config)
+
+    for t in range(T):
+        # shape is a RaggedShape of shape (B, context)
+        # contexts is a Tensor of shape (shape.NumElements(), context_size)
+        shape, contexts = decoding_streams.get_contexts()
+        # `nn.Embedding()` in torch below v1.7.1 supports only torch.int64
+        contexts = contexts.to(torch.int64)
+        # decoder_out is of shape (shape.NumElements(), 1, decoder_out_dim)
+        decoder_out = model.decoder(contexts, need_pad=False)
+        # current_encoder_out is of shape
+        # (shape.NumElements(), 1, encoder_out_dim)
+        # fmt: off
+        current_encoder_out = torch.index_select(
+            encoder_out[:, t:t + 1, :], 0, shape.row_ids(1)
+        )
+        # fmt: on
+        logits = model.joiner(
+            current_encoder_out.unsqueeze(2), decoder_out.unsqueeze(1)
+        )
+        logits = logits.squeeze(1).squeeze(1)
+        log_probs = logits.log_softmax(dim=-1)
+        decoding_streams.advance(log_probs)
+    decoding_streams.terminate_and_flush_to_streams()
+    lattice = decoding_streams.format_output(encoder_out_lens.tolist())
+    best_path = one_best_decoding(lattice)
+    hyps = get_texts(best_path)
+    new_hyps = []
+    for hyp in hyps:
+        hyp = [idx for idx in hyp if idx != unk_id]
+        new_hyps.append(hyp)
+    return new_hyps
+
 
 def greedy_search(
     model: Transducer, encoder_out: torch.Tensor, max_sym_per_frame: int
 ) -> List[int]:
-    """
+    """Greedy search for a single utterance.
     Args:
       model:
         An instance of `Transducer`.
@@ -98,6 +184,65 @@ def greedy_search(
     return hyp
 
 
+def greedy_search_batch(
+    model: Transducer, encoder_out: torch.Tensor
+) -> List[List[int]]:
+    """Greedy search in batch mode. It hardcodes --max-sym-per-frame=1.
+    Args:
+      model:
+        The transducer model.
+      encoder_out:
+        Output from the encoder. Its shape is (N, T, C), where N >= 1.
+    Returns:
+      Return a list of lists of integers containing the decoded results.
+      len(ans) equals encoder_out.size(0).
+    """
+    assert encoder_out.ndim == 3
+    assert encoder_out.size(0) >= 1, encoder_out.size(0)
+
+    device = model.device
+
+    batch_size = encoder_out.size(0)
+    T = encoder_out.size(1)
+
+    blank_id = model.decoder.blank_id
+    unk_id = model.decoder.unk_id
+    context_size = model.decoder.context_size
+
+    hyps = [[blank_id] * context_size for _ in range(batch_size)]
+
+    decoder_input = torch.tensor(
+        hyps,
+        device=device,
+        dtype=torch.int64,
+    )  # (batch_size, context_size)
+
+    decoder_out = model.decoder(decoder_input, need_pad=False)
+    # decoder_out: (batch_size, 1, decoder_out_dim)
+    for t in range(T):
+        current_encoder_out = encoder_out[:, t : t + 1, :].unsqueeze(2)  # noqa
+        # current_encoder_out's shape: (batch_size, 1, 1, encoder_out_dim)
+        logits = model.joiner(current_encoder_out, decoder_out.unsqueeze(1))
+        # logits' shape: (batch_size, 1, 1, vocab_size)
+
+        logits = logits.squeeze(1).squeeze(1)  # (batch_size, vocab_size)
+        assert logits.ndim == 2, logits.shape
+        y = logits.argmax(dim=1).tolist()
+        emitted = False
+        for i, v in enumerate(y):
+            if v != blank_id and v != unk_id:
+                hyps[i].append(v)
+                emitted = True
+        if emitted:
+            # update decoder output
+            decoder_input = [h[-context_size:] for h in hyps]
+            decoder_input = torch.tensor(decoder_input, device=device)
+            decoder_out = model.decoder(decoder_input, need_pad=False)
+
+    ans = [h[context_size:] for h in hyps]
+    return ans
+
+
 @dataclass
 class Hypothesis:
     # The predicted tokens so far.
@@ -132,8 +277,10 @@ class HypothesisList(object):
 
     def add(self, hyp: Hypothesis) -> None:
         """Add a Hypothesis to `self`.
+
         If `hyp` already exists in `self`, its probability is updated using
         `log-sum-exp` with the existed one.
+
         Args:
           hyp:
             The hypothesis to be added.
@@ -150,6 +297,7 @@ class HypothesisList(object):
     def get_most_probable(self, length_norm: bool = False) -> Hypothesis:
         """Get the most probable hypothesis, i.e., the one with
         the largest `log_prob`.
+
         Args:
           length_norm:
             If True, the `log_prob` of a hypothesis is normalized by the
@@ -166,8 +314,10 @@ class HypothesisList(object):
 
     def remove(self, hyp: Hypothesis) -> None:
         """Remove a given hypothesis.
+
         Caution:
           `self` is modified **in-place**.
+
         Args:
           hyp:
             The hypothesis to be removed from `self`.
@@ -180,8 +330,10 @@ class HypothesisList(object):
 
     def filter(self, threshold: torch.Tensor) -> "HypothesisList":
         """Remove all Hypotheses whose log_prob is less than threshold.
+
         Caution:
           `self` is not modified. Instead, a new HypothesisList is returned.
+
         Returns:
           Return a new HypothesisList containing all hypotheses from `self`
           with `log_prob` being greater than the given `threshold`.
@@ -223,6 +375,7 @@ def modified_beam_search(
     beam: int = 4,
 ) -> List[int]:
     """It limits the maximum number of symbols per frame to 1.
+
     Args:
       model:
         An instance of `Transducer`.
@@ -324,7 +477,9 @@ def beam_search(
 ) -> List[int]:
     """
     It implements Algorithm 1 in https://arxiv.org/pdf/1211.3711.pdf
+
     espnet/nets/beam_search_transducer.py#L247 is used as a reference.
+
     Args:
       model:
         An instance of `Transducer`.
@@ -346,7 +501,9 @@ def beam_search(
     device = model.device
 
     decoder_input = torch.tensor(
-        [blank_id] * context_size, device=device
+        [blank_id] * context_size,
+        device=device,
+        dtype=torch.int64,
     ).reshape(1, context_size)
 
     decoder_out = model.decoder(decoder_input, need_pad=False)
@@ -383,7 +540,9 @@ def beam_search(
 
         if cached_key not in decoder_cache:
             decoder_input = torch.tensor(
-                [y_star.ys[-context_size:]], device=device
+                [y_star.ys[-context_size:]],
+                device=device,
+                dtype=torch.int64,
             ).reshape(1, context_size)
 
             decoder_out = model.decoder(decoder_input, need_pad=False)
@@ -397,7 +556,7 @@ def beam_search(
             current_encoder_out, decoder_out.unsqueeze(1)
         )
 
-        # TODO(fangjun): Cache the blank posterior
+        # TODO(fangjun): Scale the blank posterior
 
         log_prob = logits.log_softmax(dim=-1)
         # log_prob is (1, 1, 1, vocab_size)
@@ -409,7 +568,7 @@ def beam_search(
 
         # First, process the blank symbol
         skip_log_prob = log_prob[blank_id]
-        new_y_star_log_prob = y_star.log_prob + skip_log_prob.item()
+        new_y_star_log_prob = y_star.log_prob + skip_log_prob
 
         # ys[:] returns a copy of ys
         B.add(Hypothesis(ys=y_star.ys[:], log_prob=new_y_star_log_prob))
@@ -421,9 +580,8 @@ def beam_search(
                 continue
             new_ys = y_star.ys + [i]
             new_log_prob = y_star.log_prob + v
-            A.add(
-                Hypothesis(ys=new_ys, log_prob=torch.tensor(new_log_prob))
-            )
+            A.add(Hypothesis(ys=new_ys, log_prob=new_log_prob))
+
         # Check whether B contains more than "beam" elements more probable
         # than the most probable in A
         A_most_probable = A.get_most_probable()
diff --git a/egs/tedlium3/ASR/pruned_transducer_stateless/decode.py b/egs/tedlium3/ASR/pruned_transducer_stateless/decode.py
index 1c25298d2..57901b0c6 100755
--- a/egs/tedlium3/ASR/pruned_transducer_stateless/decode.py
+++ b/egs/tedlium3/ASR/pruned_transducer_stateless/decode.py
@@ -35,7 +35,7 @@ Usage:
   --decoding-method beam_search \
   --beam-size 4
 
-(3) beam search
+(3) modified beam search
 ./pruned_transducer_stateless/decode.py \
   --epoch 29 \
   --avg 13 \
@@ -43,20 +43,37 @@ Usage:
   --max-duration 100 \
   --decoding-method modified_beam_search \
   --beam-size 4
-"""
+(4) fast beam search
+./pruned_transducer_stateless/decode.py \
+  --epoch 29 \
+  --avg 13 \
+  --exp-dir ./pruned_transducer_stateless/exp \
+  --max-duration 1500 \
+  --decoding-method fast_beam_search \
+  --beam 4 \
+  --max-contexts 4 \
+  --max-states 8
+"""
 
 import argparse
 import logging
 from collections import defaultdict
 from pathlib import Path
-from typing import Dict, List, Tuple
+from typing import Dict, List, Optional, Tuple
 
+import k2
 import sentencepiece as spm
 import torch
 import torch.nn as nn
 from asr_datamodule import TedLiumAsrDataModule
-from beam_search import beam_search, greedy_search, modified_beam_search
+from beam_search import (
+    beam_search,
+    fast_beam_search,
+    greedy_search,
+    greedy_search_batch,
+    modified_beam_search,
+)
 from conformer import Conformer
 from decoder import Decoder
 from joiner import Joiner
@@ -84,6 +101,7 @@ def get_parser():
         help="It specifies the checkpoint to use for decoding."
         "Note: Epoch counts from 0.",
     )
+
     parser.add_argument(
         "--avg",
         type=int,
@@ -115,6 +133,7 @@ def get_parser():
           - greedy_search
           - beam_search
           - modified_beam_search
+          - fast_beam_search
         """,
     )
 
@@ -122,8 +141,35 @@ def get_parser():
         "--beam-size",
         type=int,
         default=4,
+        help="""An integer indicating how many candidates we will keep for each
+        frame.
+        Used only when --decoding-method is beam_search or
+        modified_beam_search.""",
+    )
+
+    parser.add_argument(
+        "--beam",
+        type=float,
+        default=4,
+        help="""A floating point value to calculate the cutoff score during beam
+        search (i.e., `cutoff = max-score - beam`), which is the same as the
+        `beam` in Kaldi.
+        Used only when --decoding-method is fast_beam_search""",
+    )
+
+    parser.add_argument(
+        "--max-contexts",
+        type=int,
+        default=4,
         help="""Used only when --decoding-method is
-        beam_search""",
+        fast_beam_search""",
+    )
+
+    parser.add_argument(
+        "--max-states",
+        type=int,
+        default=8,
+        help="""Used only when --decoding-method is
+        fast_beam_search""",
     )
 
     parser.add_argument(
@@ -216,6 +262,7 @@ def decode_one_batch(
     model: nn.Module,
     sp: spm.SentencePieceProcessor,
     batch: dict,
+    decoding_graph: Optional[k2.Fsa] = None,
 ) -> Dict[str, List[List[str]]]:
     """Decode one batch and return the result in a dict. The dict has the
     following format:
@@ -238,6 +285,9 @@ def decode_one_batch(
       It is the return value from iterating
       `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
       for the format of the `batch`.
+      decoding_graph:
+        The decoding graph. Can be either a `k2.trivial_graph` or an HLG.
+        Used only when --decoding-method is fast_beam_search.
     Returns:
       Return the decoding result. See above description for the format of
       the returned dict.
@@ -256,36 +306,72 @@ def decode_one_batch(
         x=feature, x_lens=feature_lens
     )
     hyps = []
-    batch_size = encoder_out.size(0)
-
-    for i in range(batch_size):
-        # fmt: off
-        encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]]
-        # fmt: on
-        if params.decoding_method == "greedy_search":
-            hyp = greedy_search(
-                model=model,
-                encoder_out=encoder_out_i,
-                max_sym_per_frame=params.max_sym_per_frame,
-            )
-        elif params.decoding_method == "beam_search":
-            hyp = beam_search(
-                model=model, encoder_out=encoder_out_i, beam=params.beam_size
-            )
-        elif params.decoding_method == "modified_beam_search":
-            hyp = modified_beam_search(
-                model=model, encoder_out=encoder_out_i, beam=params.beam_size
-            )
-        else:
-            raise ValueError(
-                f"Unsupported decoding method: {params.decoding_method}"
-            )
-        hyps.append(sp.decode(hyp).split())
+    if params.decoding_method == "fast_beam_search":
+        hyp_tokens = fast_beam_search(
+            model=model,
+            decoding_graph=decoding_graph,
+            encoder_out=encoder_out,
+            encoder_out_lens=encoder_out_lens,
+            beam=params.beam,
+            max_contexts=params.max_contexts,
+            max_states=params.max_states,
+        )
+        for hyp in sp.decode(hyp_tokens):
+            hyps.append(hyp.split())
+    elif (
+        params.decoding_method == "greedy_search"
+        and params.max_sym_per_frame == 1
+    ):
+        hyp_tokens = greedy_search_batch(
+            model=model,
+            encoder_out=encoder_out,
+        )
+        for hyp in sp.decode(hyp_tokens):
+            hyps.append(hyp.split())
+    else:
+        batch_size = encoder_out.size(0)
+
+        for i in range(batch_size):
+            # fmt: off
+            encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]]
+            # fmt: on
+            if params.decoding_method == "greedy_search":
+                hyp = greedy_search(
+                    model=model,
+                    encoder_out=encoder_out_i,
+                    max_sym_per_frame=params.max_sym_per_frame,
+                )
+            elif params.decoding_method == "beam_search":
+                hyp = beam_search(
+                    model=model,
+                    encoder_out=encoder_out_i,
+                    beam=params.beam_size,
+                )
+            elif params.decoding_method == "modified_beam_search":
+                hyp = modified_beam_search(
+                    model=model,
+                    encoder_out=encoder_out_i,
+                    beam=params.beam_size,
+                )
+            else:
+                raise ValueError(
+                    f"Unsupported decoding method: {params.decoding_method}"
+                )
+            hyps.append(sp.decode(hyp).split())
 
     if params.decoding_method == "greedy_search":
         return {"greedy_search": hyps}
+    elif params.decoding_method == "fast_beam_search":
+        return {
+            (
+                f"beam_{params.beam}_"
+                f"max_contexts_{params.max_contexts}_"
+                f"max_states_{params.max_states}"
+            ): hyps
+        }
     else:
-        return {f"beam_{params.beam_size}": hyps}
+        return {f"beam_size_{params.beam_size}": hyps}
 
 
 def decode_dataset(
@@ -293,6 +379,7 @@ def decode_dataset(
     params: AttributeDict,
     model: nn.Module,
     sp: spm.SentencePieceProcessor,
+    decoding_graph: Optional[k2.Fsa] = None,
 ) -> Dict[str, List[Tuple[List[str], List[str]]]]:
     """Decode dataset.
 
@@ -305,6 +392,9 @@ def decode_dataset(
       The neural model.
     sp:
       The BPE model.
+    decoding_graph:
+      The decoding graph. Can be either a `k2.trivial_graph` or an HLG.
+      Used only when --decoding-method is fast_beam_search.
     Returns:
       Return a dict, whose key may be "greedy_search" if greedy search
       is used, or it may be "beam_7" if beam size of 7 is used.
@@ -333,6 +423,7 @@ def decode_dataset(
             model=model,
             sp=sp,
             batch=batch,
+            decoding_graph=decoding_graph,
         )
 
         for name, hyps in hyps_dict.items():
@@ -412,12 +503,17 @@ def main():
     assert params.decoding_method in (
         "greedy_search",
         "beam_search",
+        "fast_beam_search",
         "modified_beam_search",
     )
     params.res_dir = params.exp_dir / params.decoding_method
 
     params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
-    if "beam_search" in params.decoding_method:
+    if "fast_beam_search" in params.decoding_method:
+        params.suffix += f"-beam-{params.beam}"
+        params.suffix += f"-max-contexts-{params.max_contexts}"
+        params.suffix += f"-max-states-{params.max_states}"
+    elif "beam_search" in params.decoding_method:
         params.suffix += f"-beam-{params.beam_size}"
     else:
         params.suffix += f"-context-{params.context_size}"
@@ -461,6 +557,11 @@ def main():
     model.eval()
     model.device = device
 
+    if params.decoding_method == "fast_beam_search":
+        decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
+    else:
+        decoding_graph = None
+
     num_param = sum([p.numel() for p in model.parameters()])
     logging.info(f"Number of model parameters: {num_param}")
 
@@ -480,6 +581,7 @@ def main():
             params=params,
             model=model,
             sp=sp,
+            decoding_graph=decoding_graph,
         )
 
         save_results(
diff --git a/egs/tedlium3/ASR/pruned_transducer_stateless/decoder.py b/egs/tedlium3/ASR/pruned_transducer_stateless/decoder.py
index 4ffa8795d..8c7a269c3 100644
--- a/egs/tedlium3/ASR/pruned_transducer_stateless/decoder.py
+++ b/egs/tedlium3/ASR/pruned_transducer_stateless/decoder.py
@@ -65,6 +65,7 @@ class Decoder(nn.Module):
         self.unk_id = unk_id
         assert context_size >= 1, context_size
         self.context_size = context_size
+        self.vocab_size = vocab_size
         if context_size > 1:
             self.conv = nn.Conv1d(
                 in_channels=embedding_dim,
diff --git a/egs/tedlium3/ASR/pruned_transducer_stateless/pretrained.py b/egs/tedlium3/ASR/pruned_transducer_stateless/pretrained.py
index 09b19a642..2c795ede0 100644
--- a/egs/tedlium3/ASR/pruned_transducer_stateless/pretrained.py
+++ b/egs/tedlium3/ASR/pruned_transducer_stateless/pretrained.py
@@ -130,6 +130,7 @@ def get_parser():
         help="The context size in the decoder. 1 means bigram; "
         "2 means tri-gram",
     )
+
    parser.add_argument(
        "--max-sym-per-frame",
        type=int,
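
For reference, here is a minimal usage sketch (not part of the patch) of the `fast_beam_search` entry point added above, mirroring how `decode.py` wires it up. It assumes `model`, `sp`, `encoder_out`, and `encoder_out_lens` already exist, i.e. a loaded `Transducer`, its SentencePiece processor, and the output of one encoder forward pass; the beam values match the defaults introduced by this patch.

```python
# Sketch only, not part of the patch. Assumes `model`, `sp`,
# `encoder_out` and `encoder_out_lens` are already available,
# as inside decode_one_batch() above.
import k2

from beam_search import fast_beam_search

# Trivial decoding graph over the BPE vocabulary; token IDs run
# from 0 to vocab_size - 1, so the largest symbol is vocab_size - 1.
decoding_graph = k2.trivial_graph(
    model.decoder.vocab_size - 1, device=model.device
)

hyp_tokens = fast_beam_search(
    model=model,
    decoding_graph=decoding_graph,
    encoder_out=encoder_out,  # (N, T, C)
    encoder_out_lens=encoder_out_lens,  # (N,)
    beam=4.0,  # --beam 4
    max_contexts=4,  # --max-contexts 4
    max_states=8,  # --max-states 8
)
# fast_beam_search returns one list of BPE token IDs per utterance;
# map each back to text and split into words.
hyps = [text.split() for text in sp.decode(hyp_tokens)]
```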