From c30b8d3a1c095671d89321323390f8858722bcd1 Mon Sep 17 00:00:00 2001 From: marcoyang1998 <45973641+marcoyang1998@users.noreply.github.com> Date: Wed, 19 Oct 2022 16:53:29 +0800 Subject: [PATCH 01/11] fix number of parameters in RESULTS.md (#627) --- egs/librispeech/ASR/RESULTS.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/RESULTS.md b/egs/librispeech/ASR/RESULTS.md index d5a67b619..92323a556 100644 --- a/egs/librispeech/ASR/RESULTS.md +++ b/egs/librispeech/ASR/RESULTS.md @@ -1304,7 +1304,7 @@ results at: ##### Baseline-2 -It has 88.98 M parameters. Compared to the model in pruned_transducer_stateless2, its has more +It has 87.8 M parameters. Compared to the model in pruned_transducer_stateless2, its has more layers (24 v.s 12) but a narrower model (1536 feedforward dim and 384 encoder dim vs 2048 feed forward dim and 512 encoder dim). | | test-clean | test-other | comment | From 9b671e1c21c190f68183f05d33df1c134079ca18 Mon Sep 17 00:00:00 2001 From: ezerhouni <61225408+ezerhouni@users.noreply.github.com> Date: Fri, 21 Oct 2022 10:44:56 +0200 Subject: [PATCH 02/11] Add Shallow fusion in modified_beam_search (#630) * Add utility for shallow fusion * test batch size == 1 without shallow fusion * Use shallow fusion for modified-beam-search * Modified beam search with ngram rescoring * Fix code according to review Co-authored-by: Fangjun Kuang --- egs/librispeech/ASR/generate-lm.sh | 20 ++ .../ASR/lstm_transducer_stateless2/decode.py | 49 +++++ .../beam_search.py | 173 ++++++++++++++++++ icefall/__init__.py | 2 + icefall/ngram_lm.py | 164 +++++++++++++++++ test/test_ngram_lm.py | 68 +++++++ 6 files changed, 476 insertions(+) create mode 100755 egs/librispeech/ASR/generate-lm.sh create mode 100644 icefall/ngram_lm.py create mode 100755 test/test_ngram_lm.py diff --git a/egs/librispeech/ASR/generate-lm.sh b/egs/librispeech/ASR/generate-lm.sh new file mode 100755 index 000000000..6baccd381 --- /dev/null +++ b/egs/librispeech/ASR/generate-lm.sh @@ -0,0 +1,20 @@ +#!/usr/bin/env bash + +lang_dir=data/lang_bpe_500 + +for ngram in 2 3 5; do + if [ ! -f $lang_dir/${ngram}gram.arpa ]; then + ./shared/make_kn_lm.py \ + -ngram-order ${ngram} \ + -text $lang_dir/transcript_tokens.txt \ + -lm $lang_dir/${ngram}gram.arpa + fi + + if [ ! -f $lang_dir/${ngram}gram.fst.txt ]; then + python3 -m kaldilm \ + --read-symbol-table="$lang_dir/tokens.txt" \ + --disambig-symbol='#0' \ + --max-order=${ngram} \ + $lang_dir/${ngram}gram.arpa > $lang_dir/${ngram}gram.fst.txt + fi +done diff --git a/egs/librispeech/ASR/lstm_transducer_stateless2/decode.py b/egs/librispeech/ASR/lstm_transducer_stateless2/decode.py index 420202cad..c7b53ebc0 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless2/decode.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless2/decode.py @@ -115,10 +115,12 @@ from beam_search import ( greedy_search, greedy_search_batch, modified_beam_search, + modified_beam_search_ngram_rescoring, ) from librispeech import LibriSpeech from train import add_model_arguments, get_params, get_transducer_model +from icefall import NgramLm from icefall.checkpoint import ( average_checkpoints, average_checkpoints_with_averaged_model, @@ -214,6 +216,7 @@ def get_parser(): - fast_beam_search_nbest - fast_beam_search_nbest_oracle - fast_beam_search_nbest_LG + - modified_beam_search_ngram_rescoring If you use fast_beam_search_nbest_LG, you have to specify `--lang-dir`, which should contain `LG.pt`. 
""", @@ -303,6 +306,22 @@ def get_parser(): fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""", ) + parser.add_argument( + "--tokens-ngram", + type=int, + default=3, + help="""Token Ngram used for rescoring. + Used only when the decoding method is modified_beam_search_ngram_rescoring""", + ) + + parser.add_argument( + "--backoff-id", + type=int, + default=500, + help="""ID of the backoff symbol. + Used only when the decoding method is modified_beam_search_ngram_rescoring""", + ) + add_model_arguments(parser) return parser @@ -315,6 +334,8 @@ def decode_one_batch( batch: dict, word_table: Optional[k2.SymbolTable] = None, decoding_graph: Optional[k2.Fsa] = None, + ngram_lm: Optional[NgramLm] = None, + ngram_lm_scale: float = 1.0, ) -> Dict[str, List[List[str]]]: """Decode one batch and return the result in a dict. The dict has the following format: @@ -448,6 +469,17 @@ def decode_one_batch( ) for hyp in sp.decode(hyp_tokens): hyps.append(hyp.split()) + elif params.decoding_method == "modified_beam_search_ngram_rescoring": + hyp_tokens = modified_beam_search_ngram_rescoring( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + ngram_lm=ngram_lm, + ngram_lm_scale=ngram_lm_scale, + beam=params.beam_size, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) else: batch_size = encoder_out.size(0) @@ -497,6 +529,8 @@ def decode_dataset( sp: spm.SentencePieceProcessor, word_table: Optional[k2.SymbolTable] = None, decoding_graph: Optional[k2.Fsa] = None, + ngram_lm: Optional[NgramLm] = None, + ngram_lm_scale: float = 1.0, ) -> Dict[str, List[Tuple[str, List[str], List[str]]]]: """Decode dataset. @@ -546,6 +580,8 @@ def decode_dataset( decoding_graph=decoding_graph, word_table=word_table, batch=batch, + ngram_lm=ngram_lm, + ngram_lm_scale=ngram_lm_scale, ) for name, hyps in hyps_dict.items(): @@ -631,6 +667,7 @@ def main(): "fast_beam_search_nbest_LG", "fast_beam_search_nbest_oracle", "modified_beam_search", + "modified_beam_search_ngram_rescoring", ) params.res_dir = params.exp_dir / params.decoding_method @@ -655,6 +692,7 @@ def main(): else: params.suffix += f"-context-{params.context_size}" params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" + params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}" if params.use_averaged_model: params.suffix += "-use-averaged-model" @@ -768,6 +806,15 @@ def main(): model.to(device) model.eval() + lm_filename = f"{params.tokens_ngram}gram.fst.txt" + logging.info(f"lm filename: {lm_filename}") + ngram_lm = NgramLm( + str(params.lang_dir / lm_filename), + backoff_id=params.backoff_id, + is_binary=False, + ) + logging.info(f"num states: {ngram_lm.lm.num_states}") + if "fast_beam_search" in params.decoding_method: if params.decoding_method == "fast_beam_search_nbest_LG": lexicon = Lexicon(params.lang_dir) @@ -812,6 +859,8 @@ def main(): sp=sp, word_table=word_table, decoding_graph=decoding_graph, + ngram_lm=ngram_lm, + ngram_lm_scale=params.ngram_lm_scale, ) save_results( diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py b/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py index 769cd2a1d..c70618ef7 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py @@ -23,6 +23,7 @@ import sentencepiece as spm import torch from model import Transducer +from icefall import NgramLm, NgramLmStateCost from icefall.decode import Nbest, one_best_decoding from icefall.utils import 
add_eos, add_sos, get_texts @@ -656,6 +657,8 @@ class Hypothesis: # It contains only one entry. log_prob: torch.Tensor + state_cost: Optional[NgramLmStateCost] = None + @property def key(self) -> str: """Return a string representation of self.ys""" @@ -1539,3 +1542,173 @@ def fast_beam_search_with_nbest_rnn_rescoring( ans[key] = hyps return ans + + +def modified_beam_search_ngram_rescoring( + model: Transducer, + encoder_out: torch.Tensor, + encoder_out_lens: torch.Tensor, + ngram_lm: NgramLm, + ngram_lm_scale: float, + beam: int = 4, + temperature: float = 1.0, +) -> List[List[int]]: + """Beam search in batch mode with --max-sym-per-frame=1 being hardcoded. + + Args: + model: + The transducer model. + encoder_out: + Output from the encoder. Its shape is (N, T, C). + encoder_out_lens: + A 1-D tensor of shape (N,), containing number of valid frames in + encoder_out before padding. + beam: + Number of active paths during the beam search. + temperature: + Softmax temperature. + Returns: + Return a list-of-list of token IDs. ans[i] is the decoding results + for the i-th utterance. + """ + assert encoder_out.ndim == 3, encoder_out.shape + assert encoder_out.size(0) >= 1, encoder_out.size(0) + + packed_encoder_out = torch.nn.utils.rnn.pack_padded_sequence( + input=encoder_out, + lengths=encoder_out_lens.cpu(), + batch_first=True, + enforce_sorted=False, + ) + + blank_id = model.decoder.blank_id + unk_id = getattr(model, "unk_id", blank_id) + context_size = model.decoder.context_size + device = next(model.parameters()).device + lm_scale = ngram_lm_scale + + batch_size_list = packed_encoder_out.batch_sizes.tolist() + N = encoder_out.size(0) + assert torch.all(encoder_out_lens > 0), encoder_out_lens + assert N == batch_size_list[0], (N, batch_size_list) + + B = [HypothesisList() for _ in range(N)] + for i in range(N): + B[i].add( + Hypothesis( + ys=[blank_id] * context_size, + log_prob=torch.zeros(1, dtype=torch.float32, device=device), + state_cost=NgramLmStateCost(ngram_lm), + ) + ) + + encoder_out = model.joiner.encoder_proj(packed_encoder_out.data) + + offset = 0 + finalized_B = [] + for batch_size in batch_size_list: + start = offset + end = offset + batch_size + current_encoder_out = encoder_out.data[start:end] + current_encoder_out = current_encoder_out.unsqueeze(1).unsqueeze(1) + # current_encoder_out's shape is (batch_size, 1, 1, encoder_out_dim) + offset = end + + finalized_B = B[batch_size:] + finalized_B + B = B[:batch_size] + + hyps_shape = get_hyps_shape(B).to(device) + + A = [list(b) for b in B] + B = [HypothesisList() for _ in range(batch_size)] + + ys_log_probs = torch.cat( + [ + hyp.log_prob.reshape(1, 1) + hyp.state_cost.lm_score * lm_scale + for hyps in A + for hyp in hyps + ] + ) # (num_hyps, 1) + + decoder_input = torch.tensor( + [hyp.ys[-context_size:] for hyps in A for hyp in hyps], + device=device, + dtype=torch.int64, + ) # (num_hyps, context_size) + + decoder_out = model.decoder(decoder_input, need_pad=False).unsqueeze(1) + decoder_out = model.joiner.decoder_proj(decoder_out) + # decoder_out is of shape (num_hyps, 1, 1, joiner_dim) + + # Note: For torch 1.7.1 and below, it requires a torch.int64 tensor + # as index, so we use `to(torch.int64)` below. 
+ current_encoder_out = torch.index_select( + current_encoder_out, + dim=0, + index=hyps_shape.row_ids(1).to(torch.int64), + ) # (num_hyps, 1, 1, encoder_out_dim) + + logits = model.joiner( + current_encoder_out, + decoder_out, + project_input=False, + ) # (num_hyps, 1, 1, vocab_size) + + logits = logits.squeeze(1).squeeze(1) # (num_hyps, vocab_size) + + log_probs = (logits / temperature).log_softmax( + dim=-1 + ) # (num_hyps, vocab_size) + + log_probs.add_(ys_log_probs) + vocab_size = log_probs.size(-1) + log_probs = log_probs.reshape(-1) + + row_splits = hyps_shape.row_splits(1) * vocab_size + log_probs_shape = k2.ragged.create_ragged_shape2( + row_splits=row_splits, cached_tot_size=log_probs.numel() + ) + ragged_log_probs = k2.RaggedTensor( + shape=log_probs_shape, value=log_probs + ) + + for i in range(batch_size): + topk_log_probs, topk_indexes = ragged_log_probs[i].topk(beam) + + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + topk_hyp_indexes = (topk_indexes // vocab_size).tolist() + topk_token_indexes = (topk_indexes % vocab_size).tolist() + + for k in range(len(topk_hyp_indexes)): + hyp_idx = topk_hyp_indexes[k] + hyp = A[i][hyp_idx] + + new_ys = hyp.ys[:] + new_token = topk_token_indexes[k] + if new_token not in (blank_id, unk_id): + new_ys.append(new_token) + state_cost = hyp.state_cost.forward_one_step(new_token) + else: + state_cost = hyp.state_cost + + # We only keep AM scores in new_hyp.log_prob + new_log_prob = ( + topk_log_probs[k] - hyp.state_cost.lm_score * lm_scale + ) + + new_hyp = Hypothesis( + ys=new_ys, log_prob=new_log_prob, state_cost=state_cost + ) + B[i].add(new_hyp) + + B = B + finalized_B + best_hyps = [b.get_most_probable(length_norm=True) for b in B] + + sorted_ans = [h.ys[context_size:] for h in best_hyps] + ans = [] + unsorted_indices = packed_encoder_out.unsorted_indices.tolist() + for i in range(N): + ans.append(sorted_ans[unsorted_indices[i]]) + + return ans diff --git a/icefall/__init__.py b/icefall/__init__.py index 0399c8459..122226fdc 100644 --- a/icefall/__init__.py +++ b/icefall/__init__.py @@ -65,3 +65,5 @@ from .utils import ( subsequent_chunk_mask, write_error_stats, ) + +from .ngram_lm import NgramLm, NgramLmStateCost diff --git a/icefall/ngram_lm.py b/icefall/ngram_lm.py new file mode 100644 index 000000000..23185e35a --- /dev/null +++ b/icefall/ngram_lm.py @@ -0,0 +1,164 @@ +# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang) +# +# See ../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from collections import defaultdict +from typing import List, Optional, Tuple + +import kaldifst + + +class NgramLm: + def __init__( + self, + fst_filename: str, + backoff_id: int, + is_binary: bool = False, + ): + """ + Args: + fst_filename: + Path to the FST. + backoff_id: + ID of the backoff symbol. + is_binary: + True if the given file is a binary FST. 
+ """ + if is_binary: + lm = kaldifst.StdVectorFst.read(fst_filename) + else: + with open(fst_filename, "r") as f: + lm = kaldifst.compile(f.read(), acceptor=False) + + if not lm.is_ilabel_sorted: + kaldifst.arcsort(lm, sort_type="ilabel") + + self.lm = lm + self.backoff_id = backoff_id + + def _process_backoff_arcs( + self, + state: int, + cost: float, + ) -> List[Tuple[int, float]]: + """Similar to ProcessNonemitting() from Kaldi, this function + returns the list of states reachable from the given state via + backoff arcs. + + Args: + state: + The input state. + cost: + The cost of reaching the given state from the start state. + Returns: + Return a list, where each element contains a tuple with two entries: + - next_state + - cost of next_state + If there is no backoff arc leaving the input state, then return + an empty list. + """ + ans = [] + + next_state, next_cost = self._get_next_state_and_cost_without_backoff( + state=state, + label=self.backoff_id, + ) + if next_state is None: + return ans + ans.append((next_state, next_cost + cost)) + ans += self._process_backoff_arcs(next_state, next_cost + cost) + return ans + + def _get_next_state_and_cost_without_backoff( + self, state: int, label: int + ) -> Tuple[int, float]: + """TODO: Add doc.""" + arc_iter = kaldifst.ArcIterator(self.lm, state) + num_arcs = self.lm.num_arcs(state) + + # The LM is arc sorted by ilabel, so we use binary search below. + left = 0 + right = num_arcs - 1 + while left <= right: + mid = (left + right) // 2 + arc_iter.seek(mid) + arc = arc_iter.value + if arc.ilabel < label: + left = mid + 1 + elif arc.ilabel > label: + right = mid - 1 + else: + return arc.nextstate, arc.weight.value + + return None, None + + def get_next_state_and_cost( + self, + state: int, + label: int, + ) -> Tuple[List[int], List[float]]: + states = [state] + costs = [0] + + extra_states_costs = self._process_backoff_arcs( + state=state, + cost=0, + ) + + for s, c in extra_states_costs: + states.append(s) + costs.append(c) + + next_states = [] + next_costs = [] + for s, c in zip(states, costs): + ns, nc = self._get_next_state_and_cost_without_backoff(s, label) + if ns: + next_states.append(ns) + next_costs.append(c + nc) + + return next_states, next_costs + + +class NgramLmStateCost: + def __init__(self, ngram_lm: NgramLm, state_cost: Optional[dict] = None): + assert ngram_lm.lm.start == 0, ngram_lm.lm.start + self.ngram_lm = ngram_lm + if state_cost is not None: + self.state_cost = state_cost + else: + self.state_cost = defaultdict(lambda: float("inf")) + + # At the very beginning, we are at the start state with cost 0 + self.state_cost[0] = 0.0 + + def forward_one_step(self, label: int) -> "NgramLmStateCost": + state_cost = defaultdict(lambda: float("inf")) + for s, c in self.state_cost.items(): + next_states, next_costs = self.ngram_lm.get_next_state_and_cost( + s, + label, + ) + for ns, nc in zip(next_states, next_costs): + state_cost[ns] = min(state_cost[ns], c + nc) + + return NgramLmStateCost(ngram_lm=self.ngram_lm, state_cost=state_cost) + + @property + def lm_score(self) -> float: + if len(self.state_cost) == 0: + return float("-inf") + + return -1 * min(self.state_cost.values()) diff --git a/test/test_ngram_lm.py b/test/test_ngram_lm.py new file mode 100755 index 000000000..bbf6bd51c --- /dev/null +++ b/test/test_ngram_lm.py @@ -0,0 +1,68 @@ +#!/usr/bin/env python3 +# Copyright 2022 Xiaomi Corp. 
(authors: Fangjun Kuang) +# +# See ../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import graphviz +import kaldifst + +from icefall import NgramLm, NgramLmStateCost + + +def generate_fst(filename: str): + s = """ +3 5 1 1 3.00464 +3 0 3 0 5.75646 +0 1 1 1 12.0533 +0 2 2 2 7.95954 +0 9.97787 +1 4 2 2 3.35436 +1 0 3 0 7.59853 +2 0 3 0 +4 2 3 0 7.43735 +4 0.551239 +5 4 2 2 0.804938 +5 1 3 0 9.67086 +""" + fst = kaldifst.compile(s=s, acceptor=False) + fst.write(filename) + fst_dot = kaldifst.draw(fst, acceptor=False, portrait=True) + source = graphviz.Source(fst_dot) + source.render(outfile=f"{filename}.svg") + + +def main(): + filename = "test.fst" + generate_fst(filename) + ngram_lm = NgramLm(filename, backoff_id=3, is_binary=True) + for label in [1, 2, 3, 4, 5]: + print("---label---", label) + p = ngram_lm.get_next_state_and_cost(state=5, label=label) + print(p) + print("---") + + state_cost = NgramLmStateCost(ngram_lm) + s0 = state_cost.forward_one_step(1) + print(s0.state_cost) + + s1 = s0.forward_one_step(2) + print(s1.state_cost) + + s2 = s1.forward_one_step(2) + print(s2.state_cost) + + +if __name__ == "__main__": + main() From 348494888d08d5ddba2baadddcfe7df576d4bed1 Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Sat, 22 Oct 2022 13:14:44 +0800 Subject: [PATCH 03/11] Add kaldifst to requirements.txt (#631) --- requirements.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/requirements.txt b/requirements.txt index d5931e49a..1c548c50a 100644 --- a/requirements.txt +++ b/requirements.txt @@ -8,3 +8,4 @@ onnx onnxruntime --extra-index-url https://pypi.ngc.nvidia.com dill +kaldifst From 499ac24ecba64f687ff244c7d66baa5c222ecf0f Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Mon, 24 Oct 2022 15:07:29 +0800 Subject: [PATCH 04/11] Install kaldifst for GitHub actions (#632) * Install kaldifst for GitHub actions --- ...-librispeech-pruned-transducer-stateless3-2022-04-29.sh | 7 +++++-- requirements-ci.txt | 2 +- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/.github/scripts/run-librispeech-pruned-transducer-stateless3-2022-04-29.sh b/.github/scripts/run-librispeech-pruned-transducer-stateless3-2022-04-29.sh index 172d7ad4c..00580ca1f 100755 --- a/.github/scripts/run-librispeech-pruned-transducer-stateless3-2022-04-29.sh +++ b/.github/scripts/run-librispeech-pruned-transducer-stateless3-2022-04-29.sh @@ -13,9 +13,12 @@ cd egs/librispeech/ASR repo_url=https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless3-2022-04-29 log "Downloading pre-trained model from $repo_url" -git lfs install -git clone $repo_url +GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url repo=$(basename $repo_url) +pushd $repo +git lfs pull --include "data/lang_bpe_500/bpe.model" +git lfs pull --include "exp/pretrained-epoch-25-avg-6.pt" +popd log "Display test files" tree $repo/ diff --git a/requirements-ci.txt b/requirements-ci.txt index 385c8737e..b8e49899e 100644 --- a/requirements-ci.txt +++ b/requirements-ci.txt @@ -23,4 
+23,4 @@ multi_quantization onnx onnxruntime -onnx_graphsurgeon -i https://pypi.ngc.nvidia.com +kaldifst From 6709bf1e6325166fcb989b1dbb03344d6b90b7f8 Mon Sep 17 00:00:00 2001 From: Nagendra Goel Date: Thu, 27 Oct 2022 22:23:32 -0400 Subject: [PATCH 05/11] Update train.py (#635) Add the missing step to add the arguments to the parser. --- egs/librispeech/ASR/pruned_transducer_stateless3/train.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/train.py b/egs/librispeech/ASR/pruned_transducer_stateless3/train.py index 723b03e15..a8fd527d5 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless3/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless3/train.py @@ -327,7 +327,9 @@ def get_parser(): default=0.5, help="The probability to select a batch from the GigaSpeech dataset", ) - + + add_model_arguments(parser) + return parser From 581d0361cc739ce2d175e22804aac378d0155e5f Mon Sep 17 00:00:00 2001 From: Wei Kang Date: Sun, 30 Oct 2022 16:35:30 +0800 Subject: [PATCH 06/11] Fix type hints for decode.py (#638) * Fix type hints for decode.py * Fix flake8 --- egs/librispeech/ASR/pruned_transducer_stateless/decode.py | 4 ++-- egs/librispeech/ASR/pruned_transducer_stateless3/train.py | 3 +-- egs/librispeech/ASR/pruned_transducer_stateless4/decode.py | 2 +- egs/librispeech/ASR/transducer/decode.py | 2 +- 4 files changed, 5 insertions(+), 6 deletions(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless/decode.py index b11fb960a..3977f8443 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless/decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless/decode.py @@ -511,7 +511,7 @@ def decode_dataset( sp: spm.SentencePieceProcessor, word_table: Optional[k2.SymbolTable] = None, decoding_graph: Optional[k2.Fsa] = None, -) -> Dict[str, List[str, Tuple[str, List[str], List[str]]]]: +) -> Dict[str, List[Tuple[str, List[str], List[str]]]]: """Decode dataset. 
Args: @@ -585,7 +585,7 @@ def decode_dataset( def save_results( params: AttributeDict, test_set_name: str, - results_dict: Dict[str, List[str, Tuple[str, List[str], List[str]]]], + results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]], ): test_set_wers = dict() for key, results in results_dict.items(): diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/train.py b/egs/librispeech/ASR/pruned_transducer_stateless3/train.py index a8fd527d5..a74975caf 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless3/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless3/train.py @@ -327,9 +327,8 @@ def get_parser(): default=0.5, help="The probability to select a batch from the GigaSpeech dataset", ) - + add_model_arguments(parser) - return parser diff --git a/egs/librispeech/ASR/pruned_transducer_stateless4/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless4/decode.py index b75a72a15..873892bb9 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless4/decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless4/decode.py @@ -612,7 +612,7 @@ def decode_dataset( def save_results( params: AttributeDict, test_set_name: str, - results_dict: Dict[str, List[str, Tuple[List[str], List[str]]]], + results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]], ): test_set_wers = dict() for key, results in results_dict.items(): diff --git a/egs/librispeech/ASR/transducer/decode.py b/egs/librispeech/ASR/transducer/decode.py index 24f243974..5f233df87 100755 --- a/egs/librispeech/ASR/transducer/decode.py +++ b/egs/librispeech/ASR/transducer/decode.py @@ -327,7 +327,7 @@ def decode_dataset( def save_results( params: AttributeDict, test_set_name: str, - results_dict: Dict[str, List[str, Tuple[List[str], List[str]]]], + results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]], ): test_set_wers = dict() for key, results in results_dict.items(): From 1abf2863bbcd35ca02f7282154e304424ddddeec Mon Sep 17 00:00:00 2001 From: Teo Wen Shen <36886809+teowenshen@users.noreply.github.com> Date: Sun, 30 Oct 2022 23:47:21 +0900 Subject: [PATCH 07/11] fix typos (#639) --- egs/csj/ASR/.gitignore | 5 +- egs/csj/ASR/local/compute_fbank_musan.py | 128 ++++++++++++++++++++++- egs/csj/ASR/prepare.sh | 24 ++--- 3 files changed, 142 insertions(+), 15 deletions(-) mode change 120000 => 100644 egs/csj/ASR/local/compute_fbank_musan.py diff --git a/egs/csj/ASR/.gitignore b/egs/csj/ASR/.gitignore index c0a162e20..5d965832e 100644 --- a/egs/csj/ASR/.gitignore +++ b/egs/csj/ASR/.gitignore @@ -1,7 +1,8 @@ -librispeech_*.* +librispeech_* todelete* lang* notify_tg.py finetune_* misc.ini -.vscode/* \ No newline at end of file +.vscode/* +offline/* \ No newline at end of file diff --git a/egs/csj/ASR/local/compute_fbank_musan.py b/egs/csj/ASR/local/compute_fbank_musan.py deleted file mode 120000 index 5833f2484..000000000 --- a/egs/csj/ASR/local/compute_fbank_musan.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/local/compute_fbank_musan.py \ No newline at end of file diff --git a/egs/csj/ASR/local/compute_fbank_musan.py b/egs/csj/ASR/local/compute_fbank_musan.py new file mode 100644 index 000000000..44a33c4eb --- /dev/null +++ b/egs/csj/ASR/local/compute_fbank_musan.py @@ -0,0 +1,127 @@ +#!/usr/bin/env python3 +# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import logging +import os +from pathlib import Path + +import torch +from lhotse import CutSet, Fbank, FbankConfig, LilcomChunkyWriter, combine +from lhotse.recipes.utils import read_manifests_if_cached + +from icefall.utils import get_executor + + +ARGPARSE_DESCRIPTION = """ +This file computes fbank features of the musan dataset. +It looks for manifests in the directory data/manifests. + +The generated fbank features are saved in data/fbank. +""" + +# Torch's multithreaded behavior needs to be disabled or +# it wastes a lot of CPU and slow things down. +# Do this outside of main() in case it needs to take effect +# even when we are not invoking the main (e.g. when spawning subprocesses). +torch.set_num_threads(1) +torch.set_num_interop_threads(1) + + +def compute_fbank_musan(manifest_dir: Path, fbank_dir: Path): + # src_dir = Path("data/manifests") + # output_dir = Path("data/fbank") + num_jobs = min(15, os.cpu_count()) + num_mel_bins = 80 + + dataset_parts = ( + "music", + "speech", + "noise", + ) + prefix = "musan" + suffix = "jsonl.gz" + manifests = read_manifests_if_cached( + dataset_parts=dataset_parts, + output_dir=manifest_dir, + prefix=prefix, + suffix=suffix, + ) + assert manifests is not None + + assert len(manifests) == len(dataset_parts), ( + len(manifests), + len(dataset_parts), + list(manifests.keys()), + dataset_parts, + ) + + musan_cuts_path = fbank_dir / "musan_cuts.jsonl.gz" + + if musan_cuts_path.is_file(): + logging.info(f"{musan_cuts_path} already exists - skipping") + return + + logging.info("Extracting features for Musan") + + extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins)) + + with get_executor() as ex: # Initialize the executor only once. + # create chunks of Musan with duration 5 - 10 seconds + musan_cuts = ( + CutSet.from_manifests( + recordings=combine( + part["recordings"] for part in manifests.values() + ) + ) + .cut_into_windows(10.0) + .filter(lambda c: c.duration > 5) + .compute_and_store_features( + extractor=extractor, + storage_path=f"{fbank_dir}/musan_feats", + num_jobs=num_jobs if ex is None else 80, + executor=ex, + storage_type=LilcomChunkyWriter, + ) + ) + musan_cuts.to_file(musan_cuts_path) + + +def get_args(): + parser = argparse.ArgumentParser( + description=ARGPARSE_DESCRIPTION, + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + + parser.add_argument( + "--manifest-dir", type=Path, help="Path to save manifests" + ) + parser.add_argument( + "--fbank-dir", type=Path, help="Path to save fbank features" + ) + + return parser.parse_args() + + +if __name__ == "__main__": + args = get_args() + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) + + logging.basicConfig(format=formatter, level=logging.INFO) + compute_fbank_musan(args.manifest_dir, args.fbank_dir) diff --git a/egs/csj/ASR/prepare.sh b/egs/csj/ASR/prepare.sh index 269c1ec9a..052748ca6 100755 --- a/egs/csj/ASR/prepare.sh +++ b/egs/csj/ASR/prepare.sh @@ -2,7 +2,7 @@ # We assume the following directories are downloaded. 
# # - $csj_dir -# CSJ is assumed to be the USB-type directory, which should contain the following subdirectories:- +# CSJ is assumed to be the USB-type directory, which should contain the following subdirectories:- # - DATA (not used in this script) # - DOC (not used in this script) # - MODEL (not used in this script) @@ -30,7 +30,7 @@ # - music # - noise # - speech -# +# # By default, this script produces the original transcript like kaldi and espnet. Optionally, you # can generate other transcript formats by supplying your own config files. A few examples of these # config files can be found in local/conf. @@ -58,15 +58,15 @@ log() { echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*" } -if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then +if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then log "Stage 1: Prepare CSJ manifest" # If you want to generate more transcript modes, append the path to those config files at c. # Example: lhotse prepare csj $csj_dir $trans_dir $csj_manifest_dir -c local/conf/disfluent.ini # NOTE: In case multiple config files are supplied, the second config file and onwards will inherit - # the segment boundaries of the first config file. - if [ ! -e $csj_manifest_dir/.librispeech.done ]; then + # the segment boundaries of the first config file. + if [ ! -e $csj_manifest_dir/.csj.done ]; then lhotse prepare csj $csj_dir $trans_dir $csj_manifest_dir -j 4 - touch $csj_manifest_dir/.librispeech.done + touch $csj_manifest_dir/.csj.done fi fi @@ -85,20 +85,20 @@ if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then python local/compute_fbank_csj.py --manifest-dir $csj_manifest_dir \ --fbank-dir $csj_fbank_dir parts=( - train + train valid eval1 eval2 eval3 ) - for part in ${parts[@]}; do + for part in ${parts[@]}; do python local/validate_manifest.py --manifest $csj_manifest_dir/csj_cuts_$part.jsonl.gz done touch $csj_fbank_dir/.csj-validated.done fi fi -if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then +if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then log "Stage 4: Prepare CSJ lang" modes=disfluent @@ -117,14 +117,14 @@ if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then log "Stage 5: Compute fbank for musan" mkdir -p $musan_fbank_dir - if [ ! -e $musan_fbank_dir/.musan.done ]; then + if [ ! 
-e $musan_fbank_dir/.musan.done ]; then python local/compute_fbank_musan.py --manifest-dir $musan_manifest_dir --fbank-dir $musan_fbank_dir touch $musan_fbank_dir/.musan.done fi fi -if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then +if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then log "Stage 6: Show manifest statistics" python local/display_manifest_statistics.py --manifest-dir $csj_manifest_dir > $csj_manifest_dir/manifest_statistics.txt cat $csj_manifest_dir/manifest_statistics.txt -fi \ No newline at end of file +fi From 7f1c0e07b6daa058171cc4bf26233d023a2be10c Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Mon, 31 Oct 2022 13:44:40 +0800 Subject: [PATCH 08/11] Remove onnx and onnxruntime from requirements.txt (#640) * Remove onnx and onnxruntime from requirements.txt --- .../streaming-onnx-decode.py | 5 +++++ .../ASR/pruned_transducer_stateless3/onnx_check.py | 5 +++++ .../ASR/pruned_transducer_stateless3/test_onnx.py | 5 +++++ .../ASR/pruned_transducer_stateless6/model.py | 9 ++++++++- .../ASR/pruned_transducer_stateless6/vq_utils.py | 13 ++++++++----- .../ASR/pruned_transducer_stateless2/onnx_check.py | 5 +++++ .../onnx_pretrained.py | 6 ++++++ icefall/__init__.py | 1 + icefall/ngram_lm.py | 9 ++++++++- icefall/utils.py | 14 ++++++++++++++ requirements.txt | 5 ----- test/test_ngram_lm.py | 6 ++++++ 12 files changed, 71 insertions(+), 12 deletions(-) diff --git a/egs/librispeech/ASR/lstm_transducer_stateless2/streaming-onnx-decode.py b/egs/librispeech/ASR/lstm_transducer_stateless2/streaming-onnx-decode.py index 1c9ec3e89..232d3dd18 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless2/streaming-onnx-decode.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless2/streaming-onnx-decode.py @@ -42,6 +42,11 @@ import argparse import logging from typing import List, Optional, Tuple +from icefall import is_module_available + +if not is_module_available("onnxruntime"): + raise ValueError("Please 'pip install onnxruntime' first.") + import onnxruntime as ort import sentencepiece as spm import torch diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/onnx_check.py b/egs/librispeech/ASR/pruned_transducer_stateless3/onnx_check.py index fb9adb44a..d03d1d7ef 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless3/onnx_check.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless3/onnx_check.py @@ -24,6 +24,11 @@ with the given torchscript model for the same input. import argparse import logging +from icefall import is_module_available + +if not is_module_available("onnxruntime"): + raise ValueError("Please 'pip install onnxruntime' first.") + import onnxruntime as ort import torch diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/test_onnx.py b/egs/librispeech/ASR/pruned_transducer_stateless3/test_onnx.py index c55268b14..66ffbd3ec 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless3/test_onnx.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless3/test_onnx.py @@ -21,6 +21,11 @@ This file is to test that models can be exported to onnx. 
""" import os +from icefall import is_module_available + +if not is_module_available("onnxruntime"): + raise ValueError("Please 'pip install onnxruntime' first.") + import onnxruntime as ort import torch from conformer import ( diff --git a/egs/librispeech/ASR/pruned_transducer_stateless6/model.py b/egs/librispeech/ASR/pruned_transducer_stateless6/model.py index 06c4b5204..7716d19cf 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless6/model.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless6/model.py @@ -21,7 +21,6 @@ import k2 import torch import torch.nn as nn from encoder_interface import EncoderInterface -from multi_quantization.prediction import JointCodebookLoss from scaling import ScaledLinear from icefall.utils import add_sos @@ -74,6 +73,14 @@ class Transducer(nn.Module): encoder_dim, vocab_size, initial_speed=0.5 ) self.simple_lm_proj = ScaledLinear(decoder_dim, vocab_size) + + from icefall import is_module_available + + if not is_module_available("multi_quantization"): + raise ValueError("Please 'pip install multi_quantization' first.") + + from multi_quantization.prediction import JointCodebookLoss + if num_codebooks > 0: self.codebook_loss_net = JointCodebookLoss( predictor_channels=encoder_dim, diff --git a/egs/librispeech/ASR/pruned_transducer_stateless6/vq_utils.py b/egs/librispeech/ASR/pruned_transducer_stateless6/vq_utils.py index 65895c920..47cf2b14b 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless6/vq_utils.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless6/vq_utils.py @@ -28,18 +28,21 @@ from typing import List, Tuple import numpy as np import torch import torch.multiprocessing as mp -import multi_quantization as quantization +from icefall import is_module_available + +if not is_module_available("multi_quantization"): + raise ValueError("Please 'pip install multi_quantization' first.") + +import multi_quantization as quantization from asr_datamodule import LibriSpeechAsrDataModule from hubert_xlarge import HubertXlargeFineTuned -from icefall.utils import ( - AttributeDict, - setup_logger, -) from lhotse import CutSet, load_manifest from lhotse.cut import MonoCut from lhotse.features.io import NumpyHdf5Writer +from icefall.utils import AttributeDict, setup_logger + class CodebookIndexExtractor: """ diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless2/onnx_check.py b/egs/wenetspeech/ASR/pruned_transducer_stateless2/onnx_check.py index 91877ec46..c396c50ef 100755 --- a/egs/wenetspeech/ASR/pruned_transducer_stateless2/onnx_check.py +++ b/egs/wenetspeech/ASR/pruned_transducer_stateless2/onnx_check.py @@ -40,6 +40,11 @@ https://huggingface.co/luomingshuang/icefall_asr_wenetspeech_pruned_transducer_s import argparse import logging +from icefall import is_module_available + +if not is_module_available("onnxruntime"): + raise ValueError("Please 'pip install onnxruntime' first.") + import onnxruntime as ort import torch diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless2/onnx_pretrained.py b/egs/wenetspeech/ASR/pruned_transducer_stateless2/onnx_pretrained.py index 132517352..3770fbbb4 100755 --- a/egs/wenetspeech/ASR/pruned_transducer_stateless2/onnx_pretrained.py +++ b/egs/wenetspeech/ASR/pruned_transducer_stateless2/onnx_pretrained.py @@ -49,6 +49,12 @@ from typing import List import k2 import kaldifeat import numpy as np + +from icefall import is_module_available + +if not is_module_available("onnxruntime"): + raise ValueError("Please 'pip install onnxruntime' first.") + import onnxruntime as ort import torch import 
torchaudio diff --git a/icefall/__init__.py b/icefall/__init__.py index 122226fdc..27ad74213 100644 --- a/icefall/__init__.py +++ b/icefall/__init__.py @@ -50,6 +50,7 @@ from .utils import ( get_executor, get_texts, is_jit_tracing, + is_module_available, l1_norm, l2_norm, linf_norm, diff --git a/icefall/ngram_lm.py b/icefall/ngram_lm.py index 23185e35a..63885a9d0 100644 --- a/icefall/ngram_lm.py +++ b/icefall/ngram_lm.py @@ -17,7 +17,7 @@ from collections import defaultdict from typing import List, Optional, Tuple -import kaldifst +from icefall.utils import is_module_available class NgramLm: @@ -36,6 +36,11 @@ class NgramLm: is_binary: True if the given file is a binary FST. """ + if not is_module_available("kaldifst"): + raise ValueError("Please 'pip install kaldifst' first.") + + import kaldifst + if is_binary: lm = kaldifst.StdVectorFst.read(fst_filename) else: @@ -85,6 +90,8 @@ class NgramLm: self, state: int, label: int ) -> Tuple[int, float]: """TODO: Add doc.""" + import kaldifst + arc_iter = kaldifst.ArcIterator(self.lm, state) num_arcs = self.lm.num_arcs(state) diff --git a/icefall/utils.py b/icefall/utils.py index ad079222e..6c115ed16 100644 --- a/icefall/utils.py +++ b/icefall/utils.py @@ -976,3 +976,17 @@ def display_and_save_batch( y = sp.encode(supervisions["text"], out_type=int) num_tokens = sum(len(i) for i in y) logging.info(f"num tokens: {num_tokens}") + + +# `is_module_available` is copied from +# https://github.com/pytorch/audio/blob/6bad3a66a7a1c7cc05755e9ee5931b7391d2b94c/torchaudio/_internal/module_utils.py#L9 +def is_module_available(*modules: str) -> bool: + r"""Returns if a top-level module with :attr:`name` exists *without** + importing it. This is generally safer than try-catch block around a + `import X`. + + Note: "borrowed" from torchaudio: + """ + import importlib + + return all(importlib.util.find_spec(m) is not None for m in modules) diff --git a/requirements.txt b/requirements.txt index 1c548c50a..5e32af853 100644 --- a/requirements.txt +++ b/requirements.txt @@ -3,9 +3,4 @@ kaldialign sentencepiece>=0.1.96 tensorboard typeguard -multi_quantization -onnx -onnxruntime ---extra-index-url https://pypi.ngc.nvidia.com dill -kaldifst diff --git a/test/test_ngram_lm.py b/test/test_ngram_lm.py index bbf6bd51c..838c792d2 100755 --- a/test/test_ngram_lm.py +++ b/test/test_ngram_lm.py @@ -16,6 +16,12 @@ # limitations under the License. import graphviz + +from icefall import is_module_available + +if not is_module_available("kaldifst"): + raise ValueError("Please 'pip install kaldifst' first.") + import kaldifst from icefall import NgramLm, NgramLmStateCost From ff3f0263812aa7f738d976aac2d3393b3a98bcc3 Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Mon, 31 Oct 2022 19:47:43 +0800 Subject: [PATCH 09/11] Checkout the LM for aishell explicitly (#642) --- egs/aishell/ASR/prepare.sh | 3 +++ 1 file changed, 3 insertions(+) diff --git a/egs/aishell/ASR/prepare.sh b/egs/aishell/ASR/prepare.sh index f86dd8de3..eaeecfc4a 100755 --- a/egs/aishell/ASR/prepare.sh +++ b/egs/aishell/ASR/prepare.sh @@ -52,6 +52,9 @@ if [ $stage -le -1 ] && [ $stop_stage -ge -1 ]; then if [ ! 
-f $dl_dir/lm/3-gram.unpruned.arpa ]; then git clone https://huggingface.co/pkufool/aishell_lm $dl_dir/lm + pushd $dl_dir/lm + git lfs pull --include "3-gram.unpruned.arpa" + popd fi fi From 03668771d765febba32083a6a376b80a1331c1c4 Mon Sep 17 00:00:00 2001 From: Zengwei Yao Date: Tue, 1 Nov 2022 10:24:00 +0800 Subject: [PATCH 10/11] Get timestamps during decoding (#598) * print out timestamps during decoding * add word-level alignments * support to compute mean symbol delay with word-level alignments * print variance of symbol delay * update doc * support to compute delay for pruned_transducer_stateless4 * fix bug * add doc --- egs/librispeech/ASR/add_alignments.sh | 12 + .../ASR/local/add_alignment_librispeech.py | 196 ++++++++ .../ASR/lstm_transducer_stateless3/decode.py | 168 +++++-- .../ASR/lstm_transducer_stateless3/train.py | 1 + .../beam_search.py | 251 +++++++--- .../pruned_transducer_stateless4/decode.py | 169 +++++-- .../ASR/pruned_transducer_stateless4/train.py | 1 + icefall/utils.py | 446 +++++++++++++++++- 8 files changed, 1094 insertions(+), 150 deletions(-) create mode 100755 egs/librispeech/ASR/add_alignments.sh create mode 100755 egs/librispeech/ASR/local/add_alignment_librispeech.py diff --git a/egs/librispeech/ASR/add_alignments.sh b/egs/librispeech/ASR/add_alignments.sh new file mode 100755 index 000000000..5e4480bf6 --- /dev/null +++ b/egs/librispeech/ASR/add_alignments.sh @@ -0,0 +1,12 @@ +#!/usr/bin/env bash + +set -eou pipefail + +alignments_dir=data/alignment +cuts_in_dir=data/fbank +cuts_out_dir=data/fbank_ali + +python3 ./local/add_alignment_librispeech.py \ + --alignments-dir $alignments_dir \ + --cuts-in-dir $cuts_in_dir \ + --cuts-out-dir $cuts_out_dir diff --git a/egs/librispeech/ASR/local/add_alignment_librispeech.py b/egs/librispeech/ASR/local/add_alignment_librispeech.py new file mode 100755 index 000000000..cd1bcea67 --- /dev/null +++ b/egs/librispeech/ASR/local/add_alignment_librispeech.py @@ -0,0 +1,196 @@ +#!/usr/bin/env python3 +# Copyright 2022 Xiaomi Corp. (authors: Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +""" +This file adds alignments from https://github.com/CorentinJ/librispeech-alignments # noqa +to the existing fbank features dir (e.g., data/fbank) +and save cuts to a new dir (e.g., data/fbank_ali). 
+""" + +import argparse +import logging +import zipfile +from pathlib import Path +from typing import List + +from lhotse import CutSet, load_manifest_lazy +from lhotse.recipes.librispeech import parse_alignments +from lhotse.utils import is_module_available + +LIBRISPEECH_ALIGNMENTS_URL = ( + "https://drive.google.com/uc?id=1WYfgr31T-PPwMcxuAq09XZfHQO5Mw8fE" +) + +DATASET_PARTS = [ + "dev-clean", + "dev-other", + "test-clean", + "test-other", + "train-clean-100", + "train-clean-360", + "train-other-500", +] + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--alignments-dir", + type=str, + default="data/alignment", + help="The dir to save alignments.", + ) + + parser.add_argument( + "--cuts-in-dir", + type=str, + default="data/fbank", + help="The dir of the existing cuts without alignments.", + ) + + parser.add_argument( + "--cuts-out-dir", + type=str, + default="data/fbank_ali", + help="The dir to save the new cuts with alignments", + ) + + return parser + + +def download_alignments( + target_dir: str, alignments_url: str = LIBRISPEECH_ALIGNMENTS_URL +): + """ + Download and extract the alignments. + + Note: If you can not access drive.google.com, you could download the file + `LibriSpeech-Alignments.zip` from huggingface: + https://huggingface.co/Zengwei/librispeech-alignments + and extract the zip file manually. + + Args: + target_dir: + The dir to save alignments. + alignments_url: + The URL of alignments. + """ + """Modified from https://github.com/lhotse-speech/lhotse/blob/master/lhotse/recipes/librispeech.py""" # noqa + target_dir = Path(target_dir) + target_dir.mkdir(parents=True, exist_ok=True) + completed_detector = target_dir / ".ali_completed" + if completed_detector.is_file(): + logging.info("The alignment files already exist.") + return + + ali_zip_path = target_dir / "LibriSpeech-Alignments.zip" + if not ali_zip_path.is_file(): + assert is_module_available( + "gdown" + ), 'To download LibriSpeech alignments, please install "pip install gdown"' # noqa + import gdown + + gdown.download(alignments_url, output=str(ali_zip_path)) + + with zipfile.ZipFile(str(ali_zip_path)) as f: + f.extractall(path=target_dir) + completed_detector.touch() + + +def add_alignment( + alignments_dir: str, + cuts_in_dir: str = "data/fbank", + cuts_out_dir: str = "data/fbank_ali", + dataset_parts: List[str] = DATASET_PARTS, +): + """ + Add alignment info to existing cuts. + + Args: + alignments_dir: + The dir of the alignments. + cuts_in_dir: + The dir of the existing cuts. + cuts_out_dir: + The dir to save the new cuts with alignments. + dataset_parts: + Librispeech parts to add alignments. 
+ """ + alignments_dir = Path(alignments_dir) + cuts_in_dir = Path(cuts_in_dir) + cuts_out_dir = Path(cuts_out_dir) + cuts_out_dir.mkdir(parents=True, exist_ok=True) + + for part in dataset_parts: + logging.info(f"Processing {part}") + + cuts_in_path = cuts_in_dir / f"librispeech_cuts_{part}.jsonl.gz" + if not cuts_in_path.is_file(): + logging.info(f"{cuts_in_path} does not exist - skipping.") + continue + cuts_out_path = cuts_out_dir / f"librispeech_cuts_{part}.jsonl.gz" + if cuts_out_path.is_file(): + logging.info(f"{part} already exists - skipping.") + continue + + # parse alignments + alignments = {} + part_ali_dir = alignments_dir / "LibriSpeech" / part + for ali_path in part_ali_dir.rglob("*.alignment.txt"): + ali = parse_alignments(ali_path) + alignments.update(ali) + logging.info( + f"{part} has {len(alignments.keys())} cuts with alignments." + ) + + # add alignment attribute and write out + cuts_in = load_manifest_lazy(cuts_in_path) + with CutSet.open_writer(cuts_out_path) as writer: + for cut in cuts_in: + for idx, subcut in enumerate(cut.supervisions): + origin_id = subcut.id.split("_")[0] + if origin_id in alignments: + ali = alignments[origin_id] + else: + logging.info( + f"Warning: {origin_id} does not has alignment." + ) + ali = [] + subcut.alignment = {"word": ali} + writer.write(cut, flush=True) + + +def main(): + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) + logging.basicConfig(format=formatter, level=logging.INFO) + + parser = get_parser() + args = parser.parse_args() + logging.info(vars(args)) + + download_alignments(args.alignments_dir) + add_alignment(args.alignments_dir, args.cuts_in_dir, args.cuts_out_dir) + + +if __name__ == "__main__": + main() diff --git a/egs/librispeech/ASR/lstm_transducer_stateless3/decode.py b/egs/librispeech/ASR/lstm_transducer_stateless3/decode.py index 5be23c50c..052d027e3 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless3/decode.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless3/decode.py @@ -91,6 +91,22 @@ Usage: --beam 20.0 \ --max-contexts 8 \ --max-states 64 + +To evaluate symbol delay, you should: +(1) Generate cuts with word-time alignments: +./local/add_alignment_librispeech.py \ + --alignments-dir data/alignment \ + --cuts-in-dir data/fbank \ + --cuts-out-dir data/fbank_ali +(2) Set the argument "--manifest-dir data/fbank_ali" while decoding. +For example: +./lstm_transducer_stateless3/decode.py \ + --epoch 40 \ + --avg 20 \ + --exp-dir ./lstm_transducer_stateless3/exp \ + --max-duration 600 \ + --decoding-method greedy_search \ + --manifest-dir data/fbank_ali """ @@ -127,10 +143,12 @@ from icefall.checkpoint import ( from icefall.lexicon import Lexicon from icefall.utils import ( AttributeDict, + DecodingResults, + parse_hyp_and_timestamp, setup_logger, - store_transcripts, + store_transcripts_and_timestamps, str2bool, - write_error_stats, + write_error_stats_with_timestamps, ) LOG_EPS = math.log(1e-10) @@ -314,7 +332,7 @@ def decode_one_batch( batch: dict, word_table: Optional[k2.SymbolTable] = None, decoding_graph: Optional[k2.Fsa] = None, -) -> Dict[str, List[List[str]]]: +) -> Dict[str, Tuple[List[List[str]], List[List[float]]]]: """Decode one batch and return the result in a dict. The dict has the following format: @@ -322,9 +340,11 @@ def decode_one_batch( if greedy_search is used, it would be "greedy_search" If beam search with a beam size of 7 is used, it would be "beam_7" - - value: It contains the decoding result. `len(value)` equals to - batch size. 
`value[i]` is the decoding result for the i-th - utterance in the given batch. + - value: It is a tuple. `len(value[0])` and `len(value[1])` are both + equal to the batch size. `value[0][i]` and `value[1][i]` + are the decoding result and timestamps for the i-th utterance + in the given batch respectively. + Args: params: It's the return value of :func:`get_params`. @@ -343,8 +363,8 @@ def decode_one_batch( only when --decoding_method is fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. Returns: - Return the decoding result. See above description for the format of - the returned dict. + Return the decoding result and timestamps. See above description for the + format of the returned dict. """ device = next(model.parameters()).device feature = batch["inputs"] @@ -370,10 +390,8 @@ def decode_one_batch( x=feature, x_lens=feature_lens ) - hyps = [] - if params.decoding_method == "fast_beam_search": - hyp_tokens = fast_beam_search_one_best( + res = fast_beam_search_one_best( model=model, decoding_graph=decoding_graph, encoder_out=encoder_out, @@ -381,11 +399,10 @@ def decode_one_batch( beam=params.beam, max_contexts=params.max_contexts, max_states=params.max_states, + return_timestamps=True, ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) elif params.decoding_method == "fast_beam_search_nbest_LG": - hyp_tokens = fast_beam_search_nbest_LG( + res = fast_beam_search_nbest_LG( model=model, decoding_graph=decoding_graph, encoder_out=encoder_out, @@ -395,11 +412,10 @@ def decode_one_batch( max_states=params.max_states, num_paths=params.num_paths, nbest_scale=params.nbest_scale, + return_timestamps=True, ) - for hyp in hyp_tokens: - hyps.append([word_table[i] for i in hyp]) elif params.decoding_method == "fast_beam_search_nbest": - hyp_tokens = fast_beam_search_nbest( + res = fast_beam_search_nbest( model=model, decoding_graph=decoding_graph, encoder_out=encoder_out, @@ -409,11 +425,10 @@ def decode_one_batch( max_states=params.max_states, num_paths=params.num_paths, nbest_scale=params.nbest_scale, + return_timestamps=True, ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) elif params.decoding_method == "fast_beam_search_nbest_oracle": - hyp_tokens = fast_beam_search_nbest_oracle( + res = fast_beam_search_nbest_oracle( model=model, decoding_graph=decoding_graph, encoder_out=encoder_out, @@ -424,56 +439,67 @@ def decode_one_batch( num_paths=params.num_paths, ref_texts=sp.encode(supervisions["text"]), nbest_scale=params.nbest_scale, + return_timestamps=True, ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) elif ( params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1 ): - hyp_tokens = greedy_search_batch( + res = greedy_search_batch( model=model, encoder_out=encoder_out, encoder_out_lens=encoder_out_lens, + return_timestamps=True, ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) elif params.decoding_method == "modified_beam_search": - hyp_tokens = modified_beam_search( + res = modified_beam_search( model=model, encoder_out=encoder_out, encoder_out_lens=encoder_out_lens, beam=params.beam_size, + return_timestamps=True, ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) else: batch_size = encoder_out.size(0) - + tokens = [] + timestamps = [] for i in range(batch_size): # fmt: off encoder_out_i = encoder_out[i:i + 1, :encoder_out_lens[i]] # fmt: on if params.decoding_method == "greedy_search": - hyp = greedy_search( + res = greedy_search( model=model, 
encoder_out=encoder_out_i, max_sym_per_frame=params.max_sym_per_frame, + return_timestamps=True, ) elif params.decoding_method == "beam_search": - hyp = beam_search( + res = beam_search( model=model, encoder_out=encoder_out_i, beam=params.beam_size, + return_timestamps=True, ) else: raise ValueError( f"Unsupported decoding method: {params.decoding_method}" ) - hyps.append(sp.decode(hyp).split()) + tokens.extend(res.tokens) + timestamps.extend(res.timestamps) + res = DecodingResults(tokens=tokens, timestamps=timestamps) + + hyps, timestamps = parse_hyp_and_timestamp( + decoding_method=params.decoding_method, + res=res, + sp=sp, + subsampling_factor=params.subsampling_factor, + frame_shift_ms=params.frame_shift_ms, + word_table=word_table, + ) if params.decoding_method == "greedy_search": - return {"greedy_search": hyps} + return {"greedy_search": (hyps, timestamps)} elif "fast_beam_search" in params.decoding_method: key = f"beam_{params.beam}_" key += f"max_contexts_{params.max_contexts}_" @@ -484,9 +510,9 @@ def decode_one_batch( if "LG" in params.decoding_method: key += f"_ngram_lm_scale_{params.ngram_lm_scale}" - return {key: hyps} + return {key: (hyps, timestamps)} else: - return {f"beam_size_{params.beam_size}": hyps} + return {f"beam_size_{params.beam_size}": (hyps, timestamps)} def decode_dataset( @@ -496,7 +522,9 @@ def decode_dataset( sp: spm.SentencePieceProcessor, word_table: Optional[k2.SymbolTable] = None, decoding_graph: Optional[k2.Fsa] = None, -) -> Dict[str, List[Tuple[str, List[str], List[str]]]]: +) -> Dict[ + str, List[Tuple[str, List[str], List[str], List[float], List[float]]] +]: """Decode dataset. Args: @@ -517,9 +545,12 @@ def decode_dataset( Returns: Return a dict, whose key may be "greedy_search" if greedy search is used, or it may be "beam_7" if beam size of 7 is used. - Its value is a list of tuples. Each tuple contains two elements: - The first is the reference transcript, and the second is the - predicted result. + Its value is a list of tuples. 
Each tuple contains five elements: + - cut_id + - reference transcript + - predicted result + - timestamp of reference transcript + - timestamp of predicted result """ num_cuts = 0 @@ -538,6 +569,18 @@ def decode_dataset( texts = batch["supervisions"]["text"] cut_ids = [cut.id for cut in batch["supervisions"]["cut"]] + timestamps_ref = [] + for cut in batch["supervisions"]["cut"]: + for s in cut.supervisions: + time = [] + if s.alignment is not None and "word" in s.alignment: + time = [ + aliword.start + for aliword in s.alignment["word"] + if aliword.symbol != "" + ] + timestamps_ref.append(time) + hyps_dict = decode_one_batch( params=params, model=model, @@ -547,12 +590,18 @@ def decode_dataset( batch=batch, ) - for name, hyps in hyps_dict.items(): + for name, (hyps, timestamps_hyp) in hyps_dict.items(): this_batch = [] - assert len(hyps) == len(texts) - for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts): + assert len(hyps) == len(texts) and len(timestamps_hyp) == len( + timestamps_ref + ) + for cut_id, hyp_words, ref_text, time_hyp, time_ref in zip( + cut_ids, hyps, texts, timestamps_hyp, timestamps_ref + ): ref_words = ref_text.split() - this_batch.append((cut_id, ref_words, hyp_words)) + this_batch.append( + (cut_id, ref_words, hyp_words, time_ref, time_hyp) + ) results[name].extend(this_batch) @@ -570,15 +619,19 @@ def decode_dataset( def save_results( params: AttributeDict, test_set_name: str, - results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]], + results_dict: Dict[ + str, + List[Tuple[List[str], List[str], List[str], List[float], List[float]]], + ], ): test_set_wers = dict() + test_set_delays = dict() for key, results in results_dict.items(): recog_path = ( params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" ) results = sorted(results) - store_transcripts(filename=recog_path, texts=results) + store_transcripts_and_timestamps(filename=recog_path, texts=results) logging.info(f"The transcripts are stored in {recog_path}") # The following prints out WERs, per-word error statistics and aligned @@ -587,10 +640,11 @@ def save_results( params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt" ) with open(errs_filename, "w") as f: - wer = write_error_stats( + wer, mean_delay, var_delay = write_error_stats_with_timestamps( f, f"{test_set_name}-{key}", results, enable_log=True ) test_set_wers[key] = wer + test_set_delays[key] = (mean_delay, var_delay) logging.info("Wrote detailed error stats to {}".format(errs_filename)) @@ -604,6 +658,19 @@ def save_results( for key, val in test_set_wers: print("{}\t{}".format(key, val), file=f) + test_set_delays = sorted(test_set_delays.items(), key=lambda x: x[1][0]) + delays_info = ( + params.res_dir + / f"symbol-delay-summary-{test_set_name}-{key}-{params.suffix}.txt" + ) + with open(delays_info, "w") as f: + print("settings\tsymbol-delay", file=f) + for key, val in test_set_delays: + print( + "{}\tmean: {}s, variance: {}".format(key, val[0], val[1]), + file=f, + ) + s = "\nFor {}, WER of different settings are:\n".format(test_set_name) note = "\tbest for {}".format(test_set_name) for key, val in test_set_wers: @@ -611,6 +678,15 @@ def save_results( note = "" logging.info(s) + s = "\nFor {}, symbol-delay of different settings are:\n".format( + test_set_name + ) + note = "\tbest for {}".format(test_set_name) + for key, val in test_set_delays: + s += "{}\tmean: {}s, variance: {}{}\n".format(key, val[0], val[1], note) + note = "" + logging.info(s) + @torch.no_grad() def main(): diff --git 
a/egs/librispeech/ASR/lstm_transducer_stateless3/train.py b/egs/librispeech/ASR/lstm_transducer_stateless3/train.py index dc3697ae7..fa50576d8 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless3/train.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless3/train.py @@ -377,6 +377,7 @@ def get_params() -> AttributeDict: """ params = AttributeDict( { + "frame_shift_ms": 10.0, "best_train_loss": float("inf"), "best_valid_loss": float("inf"), "best_train_epoch": -1, diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py b/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py index c70618ef7..0004a24eb 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py @@ -16,7 +16,7 @@ import warnings from dataclasses import dataclass -from typing import Dict, List, Optional +from typing import Dict, List, Optional, Union import k2 import sentencepiece as spm @@ -25,7 +25,13 @@ from model import Transducer from icefall import NgramLm, NgramLmStateCost from icefall.decode import Nbest, one_best_decoding -from icefall.utils import add_eos, add_sos, get_texts +from icefall.utils import ( + DecodingResults, + add_eos, + add_sos, + get_texts, + get_texts_with_timestamp, +) def fast_beam_search_one_best( @@ -37,7 +43,8 @@ def fast_beam_search_one_best( max_states: int, max_contexts: int, temperature: float = 1.0, -) -> List[List[int]]: + return_timestamps: bool = False, +) -> Union[List[List[int]], DecodingResults]: """It limits the maximum number of symbols per frame to 1. A lattice is first obtained using fast beam search, and then @@ -61,8 +68,12 @@ def fast_beam_search_one_best( Max contexts pre stream per frame. temperature: Softmax temperature. + return_timestamps: + Whether to return timestamps. Returns: - Return the decoded result. + If return_timestamps is False, return the decoded result. + Else, return a DecodingResults object containing + decoded result and corresponding timestamps. """ lattice = fast_beam_search( model=model, @@ -76,8 +87,11 @@ def fast_beam_search_one_best( ) best_path = one_best_decoding(lattice) - hyps = get_texts(best_path) - return hyps + + if not return_timestamps: + return get_texts(best_path) + else: + return get_texts_with_timestamp(best_path) def fast_beam_search_nbest_LG( @@ -92,7 +106,8 @@ def fast_beam_search_nbest_LG( nbest_scale: float = 0.5, use_double_scores: bool = True, temperature: float = 1.0, -) -> List[List[int]]: + return_timestamps: bool = False, +) -> Union[List[List[int]], DecodingResults]: """It limits the maximum number of symbols per frame to 1. The process to get the results is: @@ -129,8 +144,12 @@ def fast_beam_search_nbest_LG( single precision. temperature: Softmax temperature. + return_timestamps: + Whether to return timestamps. Returns: - Return the decoded result. + If return_timestamps is False, return the decoded result. + Else, return a DecodingResults object containing + decoded result and corresponding timestamps. 
""" lattice = fast_beam_search( model=model, @@ -195,9 +214,10 @@ def fast_beam_search_nbest_LG( best_hyp_indexes = ragged_tot_scores.argmax() best_path = k2.index_fsa(nbest.fsa, best_hyp_indexes) - hyps = get_texts(best_path) - - return hyps + if not return_timestamps: + return get_texts(best_path) + else: + return get_texts_with_timestamp(best_path) def fast_beam_search_nbest( @@ -212,7 +232,8 @@ def fast_beam_search_nbest( nbest_scale: float = 0.5, use_double_scores: bool = True, temperature: float = 1.0, -) -> List[List[int]]: + return_timestamps: bool = False, +) -> Union[List[List[int]], DecodingResults]: """It limits the maximum number of symbols per frame to 1. The process to get the results is: @@ -249,8 +270,12 @@ def fast_beam_search_nbest( single precision. temperature: Softmax temperature. + return_timestamps: + Whether to return timestamps. Returns: - Return the decoded result. + If return_timestamps is False, return the decoded result. + Else, return a DecodingResults object containing + decoded result and corresponding timestamps. """ lattice = fast_beam_search( model=model, @@ -279,9 +304,10 @@ def fast_beam_search_nbest( best_path = k2.index_fsa(nbest.fsa, max_indexes) - hyps = get_texts(best_path) - - return hyps + if not return_timestamps: + return get_texts(best_path) + else: + return get_texts_with_timestamp(best_path) def fast_beam_search_nbest_oracle( @@ -297,7 +323,8 @@ def fast_beam_search_nbest_oracle( use_double_scores: bool = True, nbest_scale: float = 0.5, temperature: float = 1.0, -) -> List[List[int]]: + return_timestamps: bool = False, +) -> Union[List[List[int]], DecodingResults]: """It limits the maximum number of symbols per frame to 1. A lattice is first obtained using fast beam search, and then @@ -338,8 +365,12 @@ def fast_beam_search_nbest_oracle( yields more unique paths. temperature: Softmax temperature. + return_timestamps: + Whether to return timestamps. Returns: - Return the decoded result. + If return_timestamps is False, return the decoded result. + Else, return a DecodingResults object containing + decoded result and corresponding timestamps. """ lattice = fast_beam_search( model=model, @@ -378,8 +409,10 @@ def fast_beam_search_nbest_oracle( best_path = k2.index_fsa(nbest.fsa, max_indexes) - hyps = get_texts(best_path) - return hyps + if not return_timestamps: + return get_texts(best_path) + else: + return get_texts_with_timestamp(best_path) def fast_beam_search( @@ -469,8 +502,11 @@ def fast_beam_search( def greedy_search( - model: Transducer, encoder_out: torch.Tensor, max_sym_per_frame: int -) -> List[int]: + model: Transducer, + encoder_out: torch.Tensor, + max_sym_per_frame: int, + return_timestamps: bool = False, +) -> Union[List[int], DecodingResults]: """Greedy search for a single utterance. Args: model: @@ -480,8 +516,12 @@ def greedy_search( max_sym_per_frame: Maximum number of symbols per frame. If it is set to 0, the WER would be 100%. + return_timestamps: + Whether to return timestamps. Returns: - Return the decoded result. + If return_timestamps is False, return the decoded result. + Else, return a DecodingResults object containing + decoded result and corresponding timestamps. """ assert encoder_out.ndim == 3 @@ -507,6 +547,10 @@ def greedy_search( t = 0 hyp = [blank_id] * context_size + # timestamp[i] is the frame index after subsampling + # on which hyp[i] is decoded + timestamp = [] + # Maximum symbols per utterance. 
max_sym_per_utt = 1000 @@ -533,6 +577,7 @@ def greedy_search( y = logits.argmax().item() if y not in (blank_id, unk_id): hyp.append(y) + timestamp.append(t) decoder_input = torch.tensor( [hyp[-context_size:]], device=device ).reshape(1, context_size) @@ -547,14 +592,21 @@ def greedy_search( t += 1 hyp = hyp[context_size:] # remove blanks - return hyp + if not return_timestamps: + return hyp + else: + return DecodingResults( + tokens=[hyp], + timestamps=[timestamp], + ) def greedy_search_batch( model: Transducer, encoder_out: torch.Tensor, encoder_out_lens: torch.Tensor, -) -> List[List[int]]: + return_timestamps: bool = False, +) -> Union[List[List[int]], DecodingResults]: """Greedy search in batch mode. It hardcodes --max-sym-per-frame=1. Args: model: @@ -564,9 +616,12 @@ def greedy_search_batch( encoder_out_lens: A 1-D tensor of shape (N,), containing number of valid frames in encoder_out before padding. + return_timestamps: + Whether to return timestamps. Returns: - Return a list-of-list of token IDs containing the decoded results. - len(ans) equals to encoder_out.size(0). + If return_timestamps is False, return the decoded result. + Else, return a DecodingResults object containing + decoded result and corresponding timestamps. """ assert encoder_out.ndim == 3 assert encoder_out.size(0) >= 1, encoder_out.size(0) @@ -591,6 +646,10 @@ def greedy_search_batch( hyps = [[blank_id] * context_size for _ in range(N)] + # timestamp[n][i] is the frame index after subsampling + # on which hyp[n][i] is decoded + timestamps = [[] for _ in range(N)] + decoder_input = torch.tensor( hyps, device=device, @@ -604,7 +663,7 @@ def greedy_search_batch( encoder_out = model.joiner.encoder_proj(packed_encoder_out.data) offset = 0 - for batch_size in batch_size_list: + for (t, batch_size) in enumerate(batch_size_list): start = offset end = offset + batch_size current_encoder_out = encoder_out.data[start:end] @@ -626,6 +685,7 @@ def greedy_search_batch( for i, v in enumerate(y): if v not in (blank_id, unk_id): hyps[i].append(v) + timestamps[i].append(t) emitted = True if emitted: # update decoder output @@ -640,11 +700,19 @@ def greedy_search_batch( sorted_ans = [h[context_size:] for h in hyps] ans = [] + ans_timestamps = [] unsorted_indices = packed_encoder_out.unsorted_indices.tolist() for i in range(N): ans.append(sorted_ans[unsorted_indices[i]]) + ans_timestamps.append(timestamps[unsorted_indices[i]]) - return ans + if not return_timestamps: + return ans + else: + return DecodingResults( + tokens=ans, + timestamps=ans_timestamps, + ) @dataclass @@ -657,6 +725,10 @@ class Hypothesis: # It contains only one entry. log_prob: torch.Tensor + # timestamp[i] is the frame index after subsampling + # on which ys[i] is decoded + timestamp: List[int] + state_cost: Optional[NgramLmStateCost] = None @property @@ -806,7 +878,8 @@ def modified_beam_search( encoder_out_lens: torch.Tensor, beam: int = 4, temperature: float = 1.0, -) -> List[List[int]]: + return_timestamps: bool = False, +) -> Union[List[List[int]], DecodingResults]: """Beam search in batch mode with --max-sym-per-frame=1 being hardcoded. Args: @@ -821,9 +894,12 @@ def modified_beam_search( Number of active paths during the beam search. temperature: Softmax temperature. + return_timestamps: + Whether to return timestamps. Returns: - Return a list-of-list of token IDs. ans[i] is the decoding results - for the i-th utterance. + If return_timestamps is False, return the decoded result. 
+ Else, return a DecodingResults object containing + decoded result and corresponding timestamps. """ assert encoder_out.ndim == 3, encoder_out.shape assert encoder_out.size(0) >= 1, encoder_out.size(0) @@ -851,6 +927,7 @@ def modified_beam_search( Hypothesis( ys=[blank_id] * context_size, log_prob=torch.zeros(1, dtype=torch.float32, device=device), + timestamp=[], ) ) @@ -858,7 +935,7 @@ def modified_beam_search( offset = 0 finalized_B = [] - for batch_size in batch_size_list: + for (t, batch_size) in enumerate(batch_size_list): start = offset end = offset + batch_size current_encoder_out = encoder_out.data[start:end] @@ -936,30 +1013,44 @@ def modified_beam_search( new_ys = hyp.ys[:] new_token = topk_token_indexes[k] + new_timestamp = hyp.timestamp[:] if new_token not in (blank_id, unk_id): new_ys.append(new_token) + new_timestamp.append(t) new_log_prob = topk_log_probs[k] - new_hyp = Hypothesis(ys=new_ys, log_prob=new_log_prob) + new_hyp = Hypothesis( + ys=new_ys, log_prob=new_log_prob, timestamp=new_timestamp + ) B[i].add(new_hyp) B = B + finalized_B best_hyps = [b.get_most_probable(length_norm=True) for b in B] sorted_ans = [h.ys[context_size:] for h in best_hyps] + sorted_timestamps = [h.timestamp for h in best_hyps] ans = [] + ans_timestamps = [] unsorted_indices = packed_encoder_out.unsorted_indices.tolist() for i in range(N): ans.append(sorted_ans[unsorted_indices[i]]) + ans_timestamps.append(sorted_timestamps[unsorted_indices[i]]) - return ans + if not return_timestamps: + return ans + else: + return DecodingResults( + tokens=ans, + timestamps=ans_timestamps, + ) def _deprecated_modified_beam_search( model: Transducer, encoder_out: torch.Tensor, beam: int = 4, -) -> List[int]: + return_timestamps: bool = False, +) -> Union[List[int], DecodingResults]: """It limits the maximum number of symbols per frame to 1. It decodes only one utterance at a time. We keep it only for reference. @@ -974,8 +1065,13 @@ def _deprecated_modified_beam_search( A tensor of shape (N, T, C) from the encoder. Support only N==1 for now. beam: Beam size. + return_timestamps: + Whether to return timestamps. + Returns: - Return the decoded result. + If return_timestamps is False, return the decoded result. + Else, return a DecodingResults object containing + decoded result and corresponding timestamps. 
""" assert encoder_out.ndim == 3 @@ -995,6 +1091,7 @@ def _deprecated_modified_beam_search( Hypothesis( ys=[blank_id] * context_size, log_prob=torch.zeros(1, dtype=torch.float32, device=device), + timestamp=[], ) ) encoder_out = model.joiner.encoder_proj(encoder_out) @@ -1053,17 +1150,24 @@ def _deprecated_modified_beam_search( for i in range(len(topk_hyp_indexes)): hyp = A[topk_hyp_indexes[i]] new_ys = hyp.ys[:] + new_timestamp = hyp.timestamp[:] new_token = topk_token_indexes[i] if new_token not in (blank_id, unk_id): new_ys.append(new_token) + new_timestamp.append(t) new_log_prob = topk_log_probs[i] - new_hyp = Hypothesis(ys=new_ys, log_prob=new_log_prob) + new_hyp = Hypothesis( + ys=new_ys, log_prob=new_log_prob, timestamp=new_timestamp + ) B.add(new_hyp) best_hyp = B.get_most_probable(length_norm=True) ys = best_hyp.ys[context_size:] # [context_size:] to remove blanks - return ys + if not return_timestamps: + return ys + else: + return DecodingResults(tokens=[ys], timestamps=[best_hyp.timestamp]) def beam_search( @@ -1071,7 +1175,8 @@ def beam_search( encoder_out: torch.Tensor, beam: int = 4, temperature: float = 1.0, -) -> List[int]: + return_timestamps: bool = False, +) -> Union[List[int], DecodingResults]: """ It implements Algorithm 1 in https://arxiv.org/pdf/1211.3711.pdf @@ -1086,8 +1191,13 @@ def beam_search( Beam size. temperature: Softmax temperature. + return_timestamps: + Whether to return timestamps. + Returns: - Return the decoded result. + If return_timestamps is False, return the decoded result. + Else, return a DecodingResults object containing + decoded result and corresponding timestamps. """ assert encoder_out.ndim == 3 @@ -1114,7 +1224,7 @@ def beam_search( t = 0 B = HypothesisList() - B.add(Hypothesis(ys=[blank_id] * context_size, log_prob=0.0)) + B.add(Hypothesis(ys=[blank_id] * context_size, log_prob=0.0, timestamp=[])) max_sym_per_utt = 20000 @@ -1175,7 +1285,13 @@ def beam_search( new_y_star_log_prob = y_star.log_prob + skip_log_prob # ys[:] returns a copy of ys - B.add(Hypothesis(ys=y_star.ys[:], log_prob=new_y_star_log_prob)) + B.add( + Hypothesis( + ys=y_star.ys[:], + log_prob=new_y_star_log_prob, + timestamp=y_star.timestamp[:], + ) + ) # Second, process other non-blank labels values, indices = log_prob.topk(beam + 1) @@ -1184,7 +1300,14 @@ def beam_search( continue new_ys = y_star.ys + [i] new_log_prob = y_star.log_prob + v - A.add(Hypothesis(ys=new_ys, log_prob=new_log_prob)) + new_timestamp = y_star.timestamp + [t] + A.add( + Hypothesis( + ys=new_ys, + log_prob=new_log_prob, + timestamp=new_timestamp, + ) + ) # Check whether B contains more than "beam" elements more probable # than the most probable in A @@ -1200,7 +1323,11 @@ def beam_search( best_hyp = B.get_most_probable(length_norm=True) ys = best_hyp.ys[context_size:] # [context_size:] to remove blanks - return ys + + if not return_timestamps: + return ys + else: + return DecodingResults(tokens=[ys], timestamps=[best_hyp.timestamp]) def fast_beam_search_with_nbest_rescoring( @@ -1220,7 +1347,8 @@ def fast_beam_search_with_nbest_rescoring( use_double_scores: bool = True, nbest_scale: float = 0.5, temperature: float = 1.0, -) -> Dict[str, List[List[int]]]: + return_timestamps: bool = False, +) -> Dict[str, Union[List[List[int]], DecodingResults]]: """It limits the maximum number of symbols per frame to 1. A lattice is first obtained using fast beam search, num_path are selected and rescored using a given language model. 
The shortest path within the @@ -1262,10 +1390,13 @@ def fast_beam_search_with_nbest_rescoring( yields more unique paths. temperature: Softmax temperature. + return_timestamps: + Whether to return timestamps. Returns: Return the decoded result in a dict, where the key has the form - 'ngram_lm_scale_xx' and the value is the decoded results. `xx` is the - ngram LM scale value used during decoding, i.e., 0.1. + 'ngram_lm_scale_xx' and the value is the decoded results + optionally with timestamps. `xx` is the ngram LM scale value + used during decoding, i.e., 0.1. """ lattice = fast_beam_search( model=model, @@ -1343,16 +1474,18 @@ def fast_beam_search_with_nbest_rescoring( log_semiring=False, ) - ans: Dict[str, List[List[int]]] = {} + ans: Dict[str, Union[List[List[int]], DecodingResults]] = {} for s in ngram_lm_scale_list: key = f"ngram_lm_scale_{s}" tot_scores = am_scores.values + s * ngram_lm_scores ragged_tot_scores = k2.RaggedTensor(nbest.shape, tot_scores) max_indexes = ragged_tot_scores.argmax() best_path = k2.index_fsa(nbest.fsa, max_indexes) - hyps = get_texts(best_path) - ans[key] = hyps + if not return_timestamps: + ans[key] = get_texts(best_path) + else: + ans[key] = get_texts_with_timestamp(best_path) return ans @@ -1376,7 +1509,8 @@ def fast_beam_search_with_nbest_rnn_rescoring( use_double_scores: bool = True, nbest_scale: float = 0.5, temperature: float = 1.0, -) -> Dict[str, List[List[int]]]: + return_timestamps: bool = False, +) -> Dict[str, Union[List[List[int]], DecodingResults]]: """It limits the maximum number of symbols per frame to 1. A lattice is first obtained using fast beam search, num_path are selected and rescored using a given language model and a rnn-lm. @@ -1422,10 +1556,13 @@ def fast_beam_search_with_nbest_rnn_rescoring( yields more unique paths. temperature: Softmax temperature. + return_timestamps: + Whether to return timestamps. Returns: Return the decoded result in a dict, where the key has the form - 'ngram_lm_scale_xx' and the value is the decoded results. `xx` is the - ngram LM scale value used during decoding, i.e., 0.1. + 'ngram_lm_scale_xx' and the value is the decoded results + optionally with timestamps. `xx` is the ngram LM scale value + used during decoding, i.e., 0.1. """ lattice = fast_beam_search( model=model, @@ -1537,9 +1674,11 @@ def fast_beam_search_with_nbest_rnn_rescoring( ragged_tot_scores = k2.RaggedTensor(nbest.shape, tot_scores) max_indexes = ragged_tot_scores.argmax() best_path = k2.index_fsa(nbest.fsa, max_indexes) - hyps = get_texts(best_path) - ans[key] = hyps + if not return_timestamps: + ans[key] = get_texts(best_path) + else: + ans[key] = get_texts_with_timestamp(best_path) return ans diff --git a/egs/librispeech/ASR/pruned_transducer_stateless4/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless4/decode.py index 873892bb9..13697008f 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless4/decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless4/decode.py @@ -106,6 +106,22 @@ Usage: --beam 20.0 \ --max-contexts 8 \ --max-states 64 + +To evaluate symbol delay, you should: +(1) Generate cuts with word-time alignments: +./local/add_alignment_librispeech.py \ + --alignments-dir data/alignment \ + --cuts-in-dir data/fbank \ + --cuts-out-dir data/fbank_ali +(2) Set the argument "--manifest-dir data/fbank_ali" while decoding. 
+For example: +./pruned_transducer_stateless4/decode.py \ + --epoch 40 \ + --avg 20 \ + --exp-dir ./pruned_transducer_stateless4/exp \ + --max-duration 600 \ + --decoding-method greedy_search \ + --manifest-dir data/fbank_ali """ @@ -142,10 +158,12 @@ from icefall.checkpoint import ( from icefall.lexicon import Lexicon from icefall.utils import ( AttributeDict, + DecodingResults, + parse_hyp_and_timestamp, setup_logger, - store_transcripts, + store_transcripts_and_timestamps, str2bool, - write_error_stats, + write_error_stats_with_timestamps, ) LOG_EPS = math.log(1e-10) @@ -318,7 +336,7 @@ def get_parser(): "--left-context", type=int, default=64, - help="left context can be seen during decoding (in frames after subsampling)", + help="left context can be seen during decoding (in frames after subsampling)", # noqa ) parser.add_argument( @@ -350,7 +368,7 @@ def decode_one_batch( batch: dict, word_table: Optional[k2.SymbolTable] = None, decoding_graph: Optional[k2.Fsa] = None, -) -> Dict[str, List[List[str]]]: +) -> Dict[str, Tuple[List[List[str]], List[List[float]]]]: """Decode one batch and return the result in a dict. The dict has the following format: @@ -358,9 +376,10 @@ def decode_one_batch( if greedy_search is used, it would be "greedy_search" If beam search with a beam size of 7 is used, it would be "beam_7" - - value: It contains the decoding result. `len(value)` equals to - batch size. `value[i]` is the decoding result for the i-th - utterance in the given batch. + - value: It is a tuple. `len(value[0])` and `len(value[1])` are both + equal to the batch size. `value[0][i]` and `value[1][i]` + are the decoding result and timestamps for the i-th utterance + in the given batch respectively. Args: params: It's the return value of :func:`get_params`. @@ -379,8 +398,8 @@ def decode_one_batch( only when --decoding_method is fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. Returns: - Return the decoding result. See above description for the format of - the returned dict. + Return the decoding result and timestamps. See above description for the + format of the returned dict. 
""" device = next(model.parameters()).device feature = batch["inputs"] @@ -412,10 +431,8 @@ def decode_one_batch( x=feature, x_lens=feature_lens ) - hyps = [] - if params.decoding_method == "fast_beam_search": - hyp_tokens = fast_beam_search_one_best( + res = fast_beam_search_one_best( model=model, decoding_graph=decoding_graph, encoder_out=encoder_out, @@ -423,11 +440,10 @@ def decode_one_batch( beam=params.beam, max_contexts=params.max_contexts, max_states=params.max_states, + return_timestamps=True, ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) elif params.decoding_method == "fast_beam_search_nbest_LG": - hyp_tokens = fast_beam_search_nbest_LG( + res = fast_beam_search_nbest_LG( model=model, decoding_graph=decoding_graph, encoder_out=encoder_out, @@ -437,11 +453,10 @@ def decode_one_batch( max_states=params.max_states, num_paths=params.num_paths, nbest_scale=params.nbest_scale, + return_timestamps=True, ) - for hyp in hyp_tokens: - hyps.append([word_table[i] for i in hyp]) elif params.decoding_method == "fast_beam_search_nbest": - hyp_tokens = fast_beam_search_nbest( + res = fast_beam_search_nbest( model=model, decoding_graph=decoding_graph, encoder_out=encoder_out, @@ -451,11 +466,10 @@ def decode_one_batch( max_states=params.max_states, num_paths=params.num_paths, nbest_scale=params.nbest_scale, + return_timestamps=True, ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) elif params.decoding_method == "fast_beam_search_nbest_oracle": - hyp_tokens = fast_beam_search_nbest_oracle( + res = fast_beam_search_nbest_oracle( model=model, decoding_graph=decoding_graph, encoder_out=encoder_out, @@ -466,56 +480,67 @@ def decode_one_batch( num_paths=params.num_paths, ref_texts=sp.encode(supervisions["text"]), nbest_scale=params.nbest_scale, + return_timestamps=True, ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) elif ( params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1 ): - hyp_tokens = greedy_search_batch( + res = greedy_search_batch( model=model, encoder_out=encoder_out, encoder_out_lens=encoder_out_lens, + return_timestamps=True, ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) elif params.decoding_method == "modified_beam_search": - hyp_tokens = modified_beam_search( + res = modified_beam_search( model=model, encoder_out=encoder_out, encoder_out_lens=encoder_out_lens, beam=params.beam_size, + return_timestamps=True, ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) else: batch_size = encoder_out.size(0) - + tokens = [] + timestamps = [] for i in range(batch_size): # fmt: off encoder_out_i = encoder_out[i:i + 1, :encoder_out_lens[i]] # fmt: on if params.decoding_method == "greedy_search": - hyp = greedy_search( + res = greedy_search( model=model, encoder_out=encoder_out_i, max_sym_per_frame=params.max_sym_per_frame, + return_timestamps=True, ) elif params.decoding_method == "beam_search": - hyp = beam_search( + res = beam_search( model=model, encoder_out=encoder_out_i, beam=params.beam_size, + return_timestamps=True, ) else: raise ValueError( f"Unsupported decoding method: {params.decoding_method}" ) - hyps.append(sp.decode(hyp).split()) + tokens.extend(res.tokens) + timestamps.extend(res.timestamps) + res = DecodingResults(tokens=tokens, timestamps=timestamps) + + hyps, timestamps = parse_hyp_and_timestamp( + decoding_method=params.decoding_method, + res=res, + sp=sp, + subsampling_factor=params.subsampling_factor, + frame_shift_ms=params.frame_shift_ms, + word_table=word_table, + ) if 
params.decoding_method == "greedy_search": - return {"greedy_search": hyps} + return {"greedy_search": (hyps, timestamps)} elif "fast_beam_search" in params.decoding_method: key = f"beam_{params.beam}_" key += f"max_contexts_{params.max_contexts}_" @@ -526,9 +551,9 @@ def decode_one_batch( if "LG" in params.decoding_method: key += f"_ngram_lm_scale_{params.ngram_lm_scale}" - return {key: hyps} + return {key: (hyps, timestamps)} else: - return {f"beam_size_{params.beam_size}": hyps} + return {f"beam_size_{params.beam_size}": (hyps, timestamps)} def decode_dataset( @@ -538,7 +563,9 @@ def decode_dataset( sp: spm.SentencePieceProcessor, word_table: Optional[k2.SymbolTable] = None, decoding_graph: Optional[k2.Fsa] = None, -) -> Dict[str, List[Tuple[str, List[str], List[str]]]]: +) -> Dict[ + str, List[Tuple[str, List[str], List[str], List[float], List[float]]] +]: """Decode dataset. Args: @@ -559,9 +586,12 @@ def decode_dataset( Returns: Return a dict, whose key may be "greedy_search" if greedy search is used, or it may be "beam_7" if beam size of 7 is used. - Its value is a list of tuples. Each tuple contains two elements: - The first is the reference transcript, and the second is the - predicted result. + Its value is a list of tuples. Each tuple contains five elements: + - cut_id + - reference transcript + - predicted result + - timestamp of reference transcript + - timestamp of predicted result """ num_cuts = 0 @@ -580,6 +610,18 @@ def decode_dataset( texts = batch["supervisions"]["text"] cut_ids = [cut.id for cut in batch["supervisions"]["cut"]] + timestamps_ref = [] + for cut in batch["supervisions"]["cut"]: + for s in cut.supervisions: + time = [] + if s.alignment is not None and "word" in s.alignment: + time = [ + aliword.start + for aliword in s.alignment["word"] + if aliword.symbol != "" + ] + timestamps_ref.append(time) + hyps_dict = decode_one_batch( params=params, model=model, @@ -589,12 +631,18 @@ def decode_dataset( batch=batch, ) - for name, hyps in hyps_dict.items(): + for name, (hyps, timestamps_hyp) in hyps_dict.items(): this_batch = [] - assert len(hyps) == len(texts) - for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts): + assert len(hyps) == len(texts) and len(timestamps_hyp) == len( + timestamps_ref + ) + for cut_id, hyp_words, ref_text, time_hyp, time_ref in zip( + cut_ids, hyps, texts, timestamps_hyp, timestamps_ref + ): ref_words = ref_text.split() - this_batch.append((cut_id, ref_words, hyp_words)) + this_batch.append( + (cut_id, ref_words, hyp_words, time_ref, time_hyp) + ) results[name].extend(this_batch) @@ -612,15 +660,19 @@ def decode_dataset( def save_results( params: AttributeDict, test_set_name: str, - results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]], + results_dict: Dict[ + str, + List[Tuple[List[str], List[str], List[str], List[float], List[float]]], + ], ): test_set_wers = dict() + test_set_delays = dict() for key, results in results_dict.items(): recog_path = ( params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" ) results = sorted(results) - store_transcripts(filename=recog_path, texts=results) + store_transcripts_and_timestamps(filename=recog_path, texts=results) logging.info(f"The transcripts are stored in {recog_path}") # The following prints out WERs, per-word error statistics and aligned @@ -629,10 +681,11 @@ def save_results( params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt" ) with open(errs_filename, "w") as f: - wer = write_error_stats( + wer, mean_delay, var_delay = 
write_error_stats_with_timestamps( f, f"{test_set_name}-{key}", results, enable_log=True ) test_set_wers[key] = wer + test_set_delays[key] = (mean_delay, var_delay) logging.info("Wrote detailed error stats to {}".format(errs_filename)) @@ -646,6 +699,19 @@ def save_results( for key, val in test_set_wers: print("{}\t{}".format(key, val), file=f) + test_set_delays = sorted(test_set_delays.items(), key=lambda x: x[1][0]) + delays_info = ( + params.res_dir + / f"symbol-delay-summary-{test_set_name}-{key}-{params.suffix}.txt" + ) + with open(delays_info, "w") as f: + print("settings\tsymbol-delay", file=f) + for key, val in test_set_delays: + print( + "{}\tmean: {}s, variance: {}".format(key, val[0], val[1]), + file=f, + ) + s = "\nFor {}, WER of different settings are:\n".format(test_set_name) note = "\tbest for {}".format(test_set_name) for key, val in test_set_wers: @@ -653,6 +719,15 @@ def save_results( note = "" logging.info(s) + s = "\nFor {}, symbol-delay of different settings are:\n".format( + test_set_name + ) + note = "\tbest for {}".format(test_set_name) + for key, val in test_set_delays: + s += "{}\tmean: {}s, variance: {}{}\n".format(key, val[0], val[1], note) + note = "" + logging.info(s) + @torch.no_grad() def main(): diff --git a/egs/librispeech/ASR/pruned_transducer_stateless4/train.py b/egs/librispeech/ASR/pruned_transducer_stateless4/train.py index 13a5b1a51..4c55fd609 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless4/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless4/train.py @@ -386,6 +386,7 @@ def get_params() -> AttributeDict: """ params = AttributeDict( { + "frame_shift_ms": 10.0, "best_train_loss": float("inf"), "best_valid_loss": float("inf"), "best_train_epoch": -1, diff --git a/icefall/utils.py b/icefall/utils.py index 6c115ed16..45a49fb5c 100644 --- a/icefall/utils.py +++ b/icefall/utils.py @@ -24,9 +24,10 @@ import re import subprocess from collections import defaultdict from contextlib import contextmanager +from dataclasses import dataclass from datetime import datetime from pathlib import Path -from typing import Dict, Iterable, List, TextIO, Tuple, Union +from typing import Dict, Iterable, List, Optional, TextIO, Tuple, Union import k2 import k2.version @@ -248,6 +249,86 @@ def get_texts( return aux_labels.tolist() +@dataclass +class DecodingResults: + # Decoded token IDs for each utterance in the batch + tokens: List[List[int]] + + # timestamps[i][k] contains the frame number on which tokens[i][k] + # is decoded + timestamps: List[List[int]] + + # hyps[i] is the recognition results, i.e., word IDs + # for the i-th utterance with fast_beam_search_nbest_LG. + hyps: Union[List[List[int]], k2.RaggedTensor] = None + + +def get_tokens_and_timestamps(labels: List[int]) -> Tuple[List[int], List[int]]: + tokens = [] + timestamps = [] + for i, v in enumerate(labels): + if v != 0: + tokens.append(v) + timestamps.append(i) + + return tokens, timestamps + + +def get_texts_with_timestamp( + best_paths: k2.Fsa, return_ragged: bool = False +) -> DecodingResults: + """Extract the texts (as word IDs) and timestamps from the best-path FSAs. + Args: + best_paths: + A k2.Fsa with best_paths.arcs.num_axes() == 3, i.e. + containing multiple FSAs, which is expected to be the result + of k2.shortest_path (otherwise the returned values won't + be meaningful). + return_ragged: + True to return a ragged tensor with two axes [utt][word_id]. + False to return a list-of-list word IDs. 
+ Returns: + Returns a list of lists of int, containing the label sequences we + decoded. + """ + if isinstance(best_paths.aux_labels, k2.RaggedTensor): + # remove 0's and -1's. + aux_labels = best_paths.aux_labels.remove_values_leq(0) + # TODO: change arcs.shape() to arcs.shape + aux_shape = best_paths.arcs.shape().compose(aux_labels.shape) + + # remove the states and arcs axes. + aux_shape = aux_shape.remove_axis(1) + aux_shape = aux_shape.remove_axis(1) + aux_labels = k2.RaggedTensor(aux_shape, aux_labels.values) + else: + # remove axis corresponding to states. + aux_shape = best_paths.arcs.shape().remove_axis(1) + aux_labels = k2.RaggedTensor(aux_shape, best_paths.aux_labels) + # remove 0's and -1's. + aux_labels = aux_labels.remove_values_leq(0) + + assert aux_labels.num_axes == 2 + + labels_shape = best_paths.arcs.shape().remove_axis(1) + labels_list = k2.RaggedTensor( + labels_shape, best_paths.labels.contiguous() + ).tolist() + + tokens = [] + timestamps = [] + for labels in labels_list: + token, time = get_tokens_and_timestamps(labels[:-1]) + tokens.append(token) + timestamps.append(time) + + return DecodingResults( + tokens=tokens, + timestamps=timestamps, + hyps=aux_labels if return_ragged else aux_labels.tolist(), + ) + + def get_alignments(best_paths: k2.Fsa, kind: str) -> List[List[int]]: """Extract labels or aux_labels from the best-path FSAs. @@ -352,6 +433,33 @@ def store_transcripts( print(f"{cut_id}:\thyp={hyp}", file=f) +def store_transcripts_and_timestamps( + filename: Pathlike, + texts: Iterable[Tuple[str, List[str], List[str], List[float], List[float]]], +) -> None: + """Save predicted results and reference transcripts as well as their timestamps + to a file. + + Args: + filename: + File to save the results to. + texts: + An iterable of tuples. The first element is the cur_id, the second is + the reference transcript and the third element is the predicted result. + Returns: + Return None. + """ + with open(filename, "w") as f: + for cut_id, ref, hyp, time_ref, time_hyp in texts: + print(f"{cut_id}:\tref={ref}", file=f) + print(f"{cut_id}:\thyp={hyp}", file=f) + if len(time_ref) > 0: + s = "[" + ", ".join(["%0.3f" % i for i in time_ref]) + "]" + print(f"{cut_id}:\ttimestamp_ref={s}", file=f) + s = "[" + ", ".join(["%0.3f" % i for i in time_hyp]) + "]" + print(f"{cut_id}:\ttimestamp_hyp={s}", file=f) + + def write_error_stats( f: TextIO, test_set_name: str, @@ -519,6 +627,211 @@ def write_error_stats( return float(tot_err_rate) +def write_error_stats_with_timestamps( + f: TextIO, + test_set_name: str, + results: List[Tuple[str, List[str], List[str], List[float], List[float]]], + enable_log: bool = True, +) -> Tuple[float, float, float]: + """Write statistics based on predicted results and reference transcripts + as well as their timestamps. + + It will write the following to the given file: + + - WER + - number of insertions, deletions, substitutions, corrects and total + reference words. For example:: + + Errors: 23 insertions, 57 deletions, 212 substitutions, over 2606 + reference words (2337 correct) + + - The difference between the reference transcript and predicted result. + An instance is given below:: + + THE ASSOCIATION OF (EDISON->ADDISON) ILLUMINATING COMPANIES + + The above example shows that the reference word is `EDISON`, + but it is predicted to `ADDISON` (a substitution error). + + Another example is:: + + FOR THE FIRST DAY (SIR->*) I THINK + + The reference word `SIR` is missing in the predicted + results (a deletion error). 
+ results: + An iterable of tuples. The first element is the cur_id, the second is + the reference transcript and the third element is the predicted result. + enable_log: + If True, also print detailed WER to the console. + Otherwise, it is written only to the given file. + + Returns: + Return total word error rate and mean delay. + """ + subs: Dict[Tuple[str, str], int] = defaultdict(int) + ins: Dict[str, int] = defaultdict(int) + dels: Dict[str, int] = defaultdict(int) + + # `words` stores counts per word, as follows: + # corr, ref_sub, hyp_sub, ins, dels + words: Dict[str, List[int]] = defaultdict(lambda: [0, 0, 0, 0, 0]) + num_corr = 0 + ERR = "*" + # Compute mean alignment delay on the correct words + all_delay = [] + for cut_id, ref, hyp, time_ref, time_hyp in results: + ali = kaldialign.align(ref, hyp, ERR) + has_time_ref = len(time_ref) > 0 + if has_time_ref: + # pointer to timestamp_hyp + p_hyp = 0 + # pointer to timestamp_ref + p_ref = 0 + for ref_word, hyp_word in ali: + if ref_word == ERR: + ins[hyp_word] += 1 + words[hyp_word][3] += 1 + if has_time_ref: + p_hyp += 1 + elif hyp_word == ERR: + dels[ref_word] += 1 + words[ref_word][4] += 1 + if has_time_ref: + p_ref += 1 + elif hyp_word != ref_word: + subs[(ref_word, hyp_word)] += 1 + words[ref_word][1] += 1 + words[hyp_word][2] += 1 + if has_time_ref: + p_hyp += 1 + p_ref += 1 + else: + words[ref_word][0] += 1 + num_corr += 1 + if has_time_ref: + all_delay.append(time_hyp[p_hyp] - time_ref[p_ref]) + p_hyp += 1 + p_ref += 1 + if has_time_ref: + assert p_hyp == len(hyp), (p_hyp, len(hyp)) + assert p_ref == len(ref), (p_ref, len(ref)) + + ref_len = sum([len(r) for _, r, _, _, _ in results]) + sub_errs = sum(subs.values()) + ins_errs = sum(ins.values()) + del_errs = sum(dels.values()) + tot_errs = sub_errs + ins_errs + del_errs + tot_err_rate = "%.2f" % (100.0 * tot_errs / ref_len) + + mean_delay = "inf" + var_delay = "inf" + num_delay = len(all_delay) + if num_delay > 0: + mean_delay = sum(all_delay) / num_delay + var_delay = sum([(i - mean_delay) ** 2 for i in all_delay]) / num_delay + mean_delay = "%.3f" % mean_delay + var_delay = "%.3f" % var_delay + + if enable_log: + logging.info( + f"[{test_set_name}] %WER {tot_errs / ref_len:.2%} " + f"[{tot_errs} / {ref_len}, {ins_errs} ins, " + f"{del_errs} del, {sub_errs} sub ]" + ) + logging.info( + f"[{test_set_name}] %symbol-delay mean: {mean_delay}s, variance: {var_delay} " # noqa + f"computed on {num_delay} correct words" + ) + + print(f"%WER = {tot_err_rate}", file=f) + print( + f"Errors: {ins_errs} insertions, {del_errs} deletions, " + f"{sub_errs} substitutions, over {ref_len} reference " + f"words ({num_corr} correct)", + file=f, + ) + print( + "Search below for sections starting with PER-UTT DETAILS:, " + "SUBSTITUTIONS:, DELETIONS:, INSERTIONS:, PER-WORD STATS:", + file=f, + ) + + print("", file=f) + print("PER-UTT DETAILS: corr or (ref->hyp) ", file=f) + for cut_id, ref, hyp, _, _ in results: + ali = kaldialign.align(ref, hyp, ERR) + combine_successive_errors = True + if combine_successive_errors: + ali = [[[x], [y]] for x, y in ali] + for i in range(len(ali) - 1): + if ali[i][0] != ali[i][1] and ali[i + 1][0] != ali[i + 1][1]: + ali[i + 1][0] = ali[i][0] + ali[i + 1][0] + ali[i + 1][1] = ali[i][1] + ali[i + 1][1] + ali[i] = [[], []] + ali = [ + [ + list(filter(lambda a: a != ERR, x)), + list(filter(lambda a: a != ERR, y)), + ] + for x, y in ali + ] + ali = list(filter(lambda x: x != [[], []], ali)) + ali = [ + [ + ERR if x == [] else " ".join(x), + ERR if y == [] else " 
".join(y), + ] + for x, y in ali + ] + + print( + f"{cut_id}:\t" + + " ".join( + ( + ref_word + if ref_word == hyp_word + else f"({ref_word}->{hyp_word})" + for ref_word, hyp_word in ali + ) + ), + file=f, + ) + + print("", file=f) + print("SUBSTITUTIONS: count ref -> hyp", file=f) + + for count, (ref, hyp) in sorted( + [(v, k) for k, v in subs.items()], reverse=True + ): + print(f"{count} {ref} -> {hyp}", file=f) + + print("", file=f) + print("DELETIONS: count ref", file=f) + for count, ref in sorted([(v, k) for k, v in dels.items()], reverse=True): + print(f"{count} {ref}", file=f) + + print("", file=f) + print("INSERTIONS: count hyp", file=f) + for count, hyp in sorted([(v, k) for k, v in ins.items()], reverse=True): + print(f"{count} {hyp}", file=f) + + print("", file=f) + print( + "PER-WORD STATS: word corr tot_errs count_in_ref count_in_hyp", file=f + ) + for _, word, counts in sorted( + [(sum(v[1:]), k, v) for k, v in words.items()], reverse=True + ): + (corr, ref_sub, hyp_sub, ins, dels) = counts + tot_errs = ref_sub + hyp_sub + ins + dels + ref_count = corr + ref_sub + dels + hyp_count = corr + hyp_sub + ins + + print(f"{word} {corr} {tot_errs} {ref_count} {hyp_count}", file=f) + return float(tot_err_rate), float(mean_delay), float(var_delay) + + class MetricsTracker(collections.defaultdict): def __init__(self): # Passing the type 'int' to the base-class constructor @@ -978,6 +1291,137 @@ def display_and_save_batch( logging.info(f"num tokens: {num_tokens}") +def convert_timestamp( + frames: List[int], + subsampling_factor: int, + frame_shift_ms: float = 10, +) -> List[float]: + """Convert frame numbers to time (in seconds) given subsampling factor + and frame shift (in milliseconds). + + Args: + frames: + A list of frame numbers after subsampling. + subsampling_factor: + The subsampling factor of the model. + frame_shift_ms: + Frame shift in milliseconds between two contiguous frames. + Return: + Return the time in seconds corresponding to each given frame. + """ + frame_shift = frame_shift_ms / 1000.0 + time = [] + for f in frames: + time.append(f * subsampling_factor * frame_shift) + + return time + + +def parse_timestamp(tokens: List[str], timestamp: List[float]) -> List[float]: + """ + Parse timestamp of each word. + + Args: + tokens: + List of tokens. + timestamp: + List of timestamp of each token. + + Returns: + List of timestamp of each word. + """ + start_token = b"\xe2\x96\x81".decode() # '_' + assert len(tokens) == len(timestamp) + ans = [] + for i in range(len(tokens)): + flag = False + if i == 0 or tokens[i].startswith(start_token): + flag = True + if len(tokens[i]) == 1 and tokens[i].startswith(start_token): + # tokens[i] == start_token + if i == len(tokens) - 1: + # it is the last token + flag = False + elif tokens[i + 1].startswith(start_token): + # the next token also starts with start_token + flag = False + if flag: + ans.append(timestamp[i]) + return ans + + +def parse_hyp_and_timestamp( + res: DecodingResults, + decoding_method: str, + sp: spm.SentencePieceProcessor, + subsampling_factor: int, + frame_shift_ms: float = 10, + word_table: Optional[k2.SymbolTable] = None, +) -> Tuple[List[List[str]], List[List[float]]]: + """Parse hypothesis and timestamp. + + Args: + res: + A DecodingResults object. + decoding_method: + Possible values are: + - greedy_search + - beam_search + - modified_beam_search + - fast_beam_search + - fast_beam_search_nbest + - fast_beam_search_nbest_oracle + - fast_beam_search_nbest_LG + sp: + The BPE model. 
+ subsampling_factor: + The integer subsampling factor. + frame_shift_ms: + The float frame shift used for feature extraction. + word_table: + The word symbol table. + + Returns: + Return a list of hypothesis and timestamp. + """ + assert decoding_method in ( + "greedy_search", + "beam_search", + "fast_beam_search", + "fast_beam_search_nbest", + "fast_beam_search_nbest_LG", + "fast_beam_search_nbest_oracle", + "modified_beam_search", + ) + + hyps = [] + timestamps = [] + + N = len(res.tokens) + assert len(res.timestamps) == N + use_word_table = False + if decoding_method == "fast_beam_search_nbest_LG": + assert word_table is not None + use_word_table = True + + for i in range(N): + tokens = sp.id_to_piece(res.tokens[i]) + if use_word_table: + words = [word_table[i] for i in res.hyps[i]] + else: + words = sp.decode_pieces(tokens).split() + time = convert_timestamp( + res.timestamps[i], subsampling_factor, frame_shift_ms + ) + time = parse_timestamp(tokens, time) + assert len(time) == len(words), (tokens, words) + + hyps.append(words) + timestamps.append(time) + + return hyps, timestamps + + # `is_module_available` is copied from # https://github.com/pytorch/audio/blob/6bad3a66a7a1c7cc05755e9ee5931b7391d2b94c/torchaudio/_internal/module_utils.py#L9 def is_module_available(*modules: str) -> bool: From d389524d457a4ebd9a3f216dfb4f4ded4eceb07f Mon Sep 17 00:00:00 2001 From: Wei Kang Date: Tue, 1 Nov 2022 11:09:56 +0800 Subject: [PATCH 11/11] remove tail padding for non-streaming models (#625) --- .../ASR/pruned_transducer_stateless2/decode.py | 13 ++++++------- .../ASR/pruned_transducer_stateless3/decode.py | 13 ++++++------- .../ASR/pruned_transducer_stateless4/decode.py | 13 ++++++------- .../ASR/pruned_transducer_stateless5/decode.py | 13 ++++++------- 4 files changed, 24 insertions(+), 28 deletions(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless2/decode.py index 7852dafc9..3b834b919 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/decode.py @@ -380,14 +380,13 @@ def decode_one_batch( supervisions = batch["supervisions"] feature_lens = supervisions["num_frames"].to(device) - feature_lens += params.left_context - feature = torch.nn.functional.pad( - feature, - pad=(0, 0, 0, params.left_context), - value=LOG_EPS, - ) - if params.simulate_streaming: + feature_lens += params.left_context + feature = torch.nn.functional.pad( + feature, + pad=(0, 0, 0, params.left_context), + value=LOG_EPS, + ) encoder_out, encoder_out_lens, _ = model.encoder.streaming_forward( x=feature, x_lens=feature_lens, diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless3/decode.py index fa4f1e7d9..0f30792e3 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless3/decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless3/decode.py @@ -462,14 +462,13 @@ def decode_one_batch( supervisions = batch["supervisions"] feature_lens = supervisions["num_frames"].to(device) - feature_lens += params.left_context - feature = torch.nn.functional.pad( - feature, - pad=(0, 0, 0, params.left_context), - value=LOG_EPS, - ) - if params.simulate_streaming: + feature_lens += params.left_context + feature = torch.nn.functional.pad( + feature, + pad=(0, 0, 0, params.left_context), + value=LOG_EPS, + ) encoder_out, encoder_out_lens, _ = model.encoder.streaming_forward( x=feature, x_lens=feature_lens, diff --git 
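Before the last patch in the series, here is a short worked example of the two timestamp helpers added to icefall/utils.py above. This is an editor's illustration, not part of the patch; it assumes the patched icefall is importable, and the BPE pieces and frame indices are made up.

from icefall.utils import convert_timestamp, parse_timestamp

# BPE pieces of "HELLO WORLD" and the (subsampled) frame indices on which
# they were decoded; "▁" (U+2581) marks the start of a word.
tokens = ["▁HE", "LLO", "▁WORLD"]
frames = [3, 5, 12]

# frame * subsampling_factor * frame_shift: 3 * 4 * 0.010 s = 0.12 s, etc.
token_times = convert_timestamp(frames, subsampling_factor=4, frame_shift_ms=10)
# -> [0.12, 0.20, 0.48] (up to float rounding)

# Keep only the time of the piece that starts each word.
word_times = parse_timestamp(tokens, token_times)
# -> [0.12, 0.48], one entry per word in ["HELLO", "WORLD"]
print(token_times, word_times)
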
a/egs/librispeech/ASR/pruned_transducer_stateless4/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless4/decode.py index 13697008f..85097a01a 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless4/decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless4/decode.py @@ -411,14 +411,13 @@ def decode_one_batch( supervisions = batch["supervisions"] feature_lens = supervisions["num_frames"].to(device) - feature_lens += params.left_context - feature = torch.nn.functional.pad( - feature, - pad=(0, 0, 0, params.left_context), - value=LOG_EPS, - ) - if params.simulate_streaming: + feature_lens += params.left_context + feature = torch.nn.functional.pad( + feature, + pad=(0, 0, 0, params.left_context), + value=LOG_EPS, + ) encoder_out, encoder_out_lens, _ = model.encoder.streaming_forward( x=feature, x_lens=feature_lens, diff --git a/egs/librispeech/ASR/pruned_transducer_stateless5/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless5/decode.py index 96103500b..632932214 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless5/decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless5/decode.py @@ -378,14 +378,13 @@ def decode_one_batch( supervisions = batch["supervisions"] feature_lens = supervisions["num_frames"].to(device) - feature_lens += params.left_context - feature = torch.nn.functional.pad( - feature, - pad=(0, 0, 0, params.left_context), - value=LOG_EPS, - ) - if params.simulate_streaming: + feature_lens += params.left_context + feature = torch.nn.functional.pad( + feature, + pad=(0, 0, 0, params.left_context), + value=LOG_EPS, + ) encoder_out, encoder_out_lens, _ = model.encoder.streaming_forward( x=feature, x_lens=feature_lens,
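To make the intent of this last patch explicit: before it, these decode.py scripts always padded the features with left_context extra frames at the tail; after it, that padding (and the matching feature_lens adjustment) happens only in the simulated-streaming branch, so offline decoding sees the unpadded features. Below is a rough sketch of the resulting control flow, an editor's paraphrase of the hunks above rather than a verbatim excerpt; the remaining arguments of streaming_forward are omitted.

if params.simulate_streaming:
    # Tail padding is now applied only when streaming decoding is simulated.
    feature_lens += params.left_context
    feature = torch.nn.functional.pad(
        feature,
        # (0, 0): keep the feature dim as is; (0, left_context): pad the end
        # of the time axis of the (N, T, C) tensor.
        pad=(0, 0, 0, params.left_context),
        value=LOG_EPS,
    )
    encoder_out, encoder_out_lens, _ = model.encoder.streaming_forward(
        x=feature,
        x_lens=feature_lens,
        # ... streaming options unchanged by this patch
    )
else:
    # Offline decoding: run the encoder on the unpadded features.
    encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens)
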