From b2b4d9e0b60b65d6ecd40cc252d60930784bb830 Mon Sep 17 00:00:00 2001 From: Wei Kang Date: Mon, 21 Mar 2022 16:22:25 +0800 Subject: [PATCH 01/16] Add fast beam search decoding (#250) * Add fast beam search decoding * Minor fixes * Minor fixes * Minor fixes * Fix comments * Fix comments --- .../beam_search.py | 82 ++++++++++ .../ASR/pruned_transducer_stateless/decode.py | 149 ++++++++++++++---- .../pruned_transducer_stateless/decoder.py | 1 + 3 files changed, 203 insertions(+), 29 deletions(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless/beam_search.py b/egs/librispeech/ASR/pruned_transducer_stateless/beam_search.py index 38ab16507..651854999 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless/beam_search.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless/beam_search.py @@ -17,9 +17,91 @@ from dataclasses import dataclass from typing import Dict, List, Optional +import k2 import torch from model import Transducer +from icefall.decode import one_best_decoding +from icefall.utils import get_texts + + +def fast_beam_search( + model: Transducer, + decoding_graph: k2.Fsa, + encoder_out: torch.Tensor, + encoder_out_lens: torch.Tensor, + beam: float, + max_states: int, + max_contexts: int, +) -> List[List[int]]: + """It limits the maximum number of symbols per frame to 1. + + Args: + model: + An instance of `Transducer`. + decoding_graph: + Decoding graph used for decoding, may be a TrivialGraph or a HLG. + encoder_out: + A tensor of shape (N, T, C) from the encoder. + encoder_out_lens: + A tensor of shape (N,) containing the number of frames in `encoder_out` + before padding. + beam: + Beam value, similar to the beam used in Kaldi.. + max_states: + Max states per stream per frame. + max_contexts: + Max contexts pre stream per frame. + Returns: + Return the decoded result. + """ + assert encoder_out.ndim == 3 + + context_size = model.decoder.context_size + vocab_size = model.decoder.vocab_size + + B, T, C = encoder_out.shape + + config = k2.RnntDecodingConfig( + vocab_size=vocab_size, + decoder_history_len=context_size, + beam=beam, + max_contexts=max_contexts, + max_states=max_states, + ) + individual_streams = [] + for i in range(B): + individual_streams.append(k2.RnntDecodingStream(decoding_graph)) + decoding_streams = k2.RnntDecodingStreams(individual_streams, config) + + for t in range(T): + # shape is a RaggedShape of shape (B, context) + # contexts is a Tensor of shape (shape.NumElements(), context_size) + shape, contexts = decoding_streams.get_contexts() + # `nn.Embedding()` in torch below v1.7.1 supports only torch.int64 + contexts = contexts.to(torch.int64) + # decoder_out is of shape (shape.NumElements(), 1, decoder_out_dim) + decoder_out = model.decoder(contexts, need_pad=False) + # current_encoder_out is of shape + # (shape.NumElements(), 1, encoder_out_dim) + # fmt: off + current_encoder_out = torch.index_select( + encoder_out[:, t:t + 1, :], 0, shape.row_ids(1) + ) + # fmt: on + logits = model.joiner( + current_encoder_out.unsqueeze(2), decoder_out.unsqueeze(1) + ) + logits = logits.squeeze(1).squeeze(1) + log_probs = logits.log_softmax(dim=-1) + decoding_streams.advance(log_probs) + decoding_streams.terminate_and_flush_to_streams() + lattice = decoding_streams.format_output(encoder_out_lens.tolist()) + + best_path = one_best_decoding(lattice) + hyps = get_texts(best_path) + return hyps + def greedy_search( model: Transducer, encoder_out: torch.Tensor, max_sym_per_frame: int diff --git a/egs/librispeech/ASR/pruned_transducer_stateless/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless/decode.py index fedf663b8..ad76411c0 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless/decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless/decode.py @@ -42,6 +42,17 @@ Usage: --max-duration 100 \ --decoding-method modified_beam_search \ --beam-size 4 + +(4) fast beam search +./pruned_transducer_stateless/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless/exp \ + --max-duration 1500 \ + --decoding-method fast_beam_search \ + --beam 4 \ + --max-contexts 4 \ + --max-states 8 """ @@ -49,13 +60,19 @@ import argparse import logging from collections import defaultdict from pathlib import Path -from typing import Dict, List, Tuple +from typing import Dict, List, Optional, Tuple +import k2 import sentencepiece as spm import torch import torch.nn as nn from asr_datamodule import LibriSpeechAsrDataModule -from beam_search import beam_search, greedy_search, modified_beam_search +from beam_search import ( + beam_search, + fast_beam_search, + greedy_search, + modified_beam_search, +) from train import get_params, get_transducer_model from icefall.checkpoint import ( @@ -125,6 +142,7 @@ def get_parser(): - greedy_search - beam_search - modified_beam_search + - fast_beam_search """, ) @@ -132,8 +150,35 @@ def get_parser(): "--beam-size", type=int, default=4, + help="""An interger indicating how many candidates we will keep for each + frame. Used only when --decoding-method is beam_search or + modified_beam_search.""", + ) + + parser.add_argument( + "--beam", + type=float, + default=4, + help="""A floating point value to calculate the cutoff score during beam + search (i.e., `cutoff = max-score - beam`), which is the same as the + `beam` in Kaldi. + Used only when --decoding-method is fast_beam_search""", + ) + + parser.add_argument( + "--max-contexts", + type=int, + default=4, help="""Used only when --decoding-method is - beam_search or modified_beam_search""", + fast_beam_search""", + ) + + parser.add_argument( + "--max-states", + type=int, + default=8, + help="""Used only when --decoding-method is + fast_beam_search""", ) parser.add_argument( @@ -159,6 +204,7 @@ def decode_one_batch( model: nn.Module, sp: spm.SentencePieceProcessor, batch: dict, + decoding_graph: Optional[k2.Fsa] = None, ) -> Dict[str, List[List[str]]]: """Decode one batch and return the result in a dict. The dict has the following format: @@ -181,6 +227,9 @@ def decode_one_batch( It is the return value from iterating `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation for the format of the `batch`. + decoding_graph: + The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used + only when --decoding_method is fast_beam_search. Returns: Return the decoding result. See above description for the format of the returned dict. @@ -199,36 +248,62 @@ def decode_one_batch( x=feature, x_lens=feature_lens ) hyps = [] - batch_size = encoder_out.size(0) - for i in range(batch_size): - # fmt: off - encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]] - # fmt: on - if params.decoding_method == "greedy_search": - hyp = greedy_search( - model=model, - encoder_out=encoder_out_i, - max_sym_per_frame=params.max_sym_per_frame, - ) - elif params.decoding_method == "beam_search": - hyp = beam_search( - model=model, encoder_out=encoder_out_i, beam=params.beam_size - ) - elif params.decoding_method == "modified_beam_search": - hyp = modified_beam_search( - model=model, encoder_out=encoder_out_i, beam=params.beam_size - ) - else: - raise ValueError( - f"Unsupported decoding method: {params.decoding_method}" - ) - hyps.append(sp.decode(hyp).split()) + if params.decoding_method == "fast_beam_search": + hyp_tokens = fast_beam_search( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + else: + batch_size = encoder_out.size(0) + + for i in range(batch_size): + # fmt: off + encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]] + # fmt: on + if params.decoding_method == "greedy_search": + hyp = greedy_search( + model=model, + encoder_out=encoder_out_i, + max_sym_per_frame=params.max_sym_per_frame, + ) + elif params.decoding_method == "beam_search": + hyp = beam_search( + model=model, + encoder_out=encoder_out_i, + beam=params.beam_size, + ) + elif params.decoding_method == "modified_beam_search": + hyp = modified_beam_search( + model=model, + encoder_out=encoder_out_i, + beam=params.beam_size, + ) + else: + raise ValueError( + f"Unsupported decoding method: {params.decoding_method}" + ) + hyps.append(sp.decode(hyp).split()) if params.decoding_method == "greedy_search": return {"greedy_search": hyps} + elif params.decoding_method == "fast_beam_search": + return { + ( + f"beam_{params.beam}_" + f"max_contexts_{params.max_contexts}_" + f"max_states_{params.max_states}" + ): hyps + } else: - return {f"beam_{params.beam_size}": hyps} + return {f"beam_size_{params.beam_size}": hyps} def decode_dataset( @@ -236,6 +311,7 @@ def decode_dataset( params: AttributeDict, model: nn.Module, sp: spm.SentencePieceProcessor, + decoding_graph: Optional[k2.Fsa] = None, ) -> Dict[str, List[Tuple[List[str], List[str]]]]: """Decode dataset. @@ -248,6 +324,9 @@ def decode_dataset( The neural model. sp: The BPE model. + decoding_graph: + The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used + only when --decoding_method is fast_beam_search. Returns: Return a dict, whose key may be "greedy_search" if greedy search is used, or it may be "beam_7" if beam size of 7 is used. @@ -275,6 +354,7 @@ def decode_dataset( params=params, model=model, sp=sp, + decoding_graph=decoding_graph, batch=batch, ) @@ -355,12 +435,17 @@ def main(): assert params.decoding_method in ( "greedy_search", "beam_search", + "fast_beam_search", "modified_beam_search", ) params.res_dir = params.exp_dir / params.decoding_method params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" - if "beam_search" in params.decoding_method: + if "fast_beam_search" in params.decoding_method: + params.suffix += f"-beam-{params.beam}" + params.suffix += f"-max-contexts-{params.max_contexts}" + params.suffix += f"-max-states-{params.max_states}" + elif "beam_search" in params.decoding_method: params.suffix += f"-beam-{params.beam_size}" else: params.suffix += f"-context-{params.context_size}" @@ -408,6 +493,11 @@ def main(): model.eval() model.device = device + if params.decoding_method == "fast_beam_search": + decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) + else: + decoding_graph = None + num_param = sum([p.numel() for p in model.parameters()]) logging.info(f"Number of model parameters: {num_param}") @@ -428,6 +518,7 @@ def main(): params=params, model=model, sp=sp, + decoding_graph=decoding_graph, ) save_results( diff --git a/egs/librispeech/ASR/pruned_transducer_stateless/decoder.py b/egs/librispeech/ASR/pruned_transducer_stateless/decoder.py index 3d4e69a4b..8c728fdc5 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless/decoder.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless/decoder.py @@ -61,6 +61,7 @@ class Decoder(nn.Module): assert context_size >= 1, context_size self.context_size = context_size + self.vocab_size = vocab_size if context_size > 1: self.conv = nn.Conv1d( in_channels=embedding_dim, From d5c78a2238a4b6baeddf2f741c23cd0de11f5d2f Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Tue, 22 Mar 2022 10:32:22 +0800 Subject: [PATCH 02/16] Implement greedy search in batch mode for transducer decoding. (#262) --- .../beam_search.py | 60 ++++++++++++++++++- .../ASR/pruned_transducer_stateless/decode.py | 11 ++++ 2 files changed, 70 insertions(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless/beam_search.py b/egs/librispeech/ASR/pruned_transducer_stateless/beam_search.py index 651854999..05b027214 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless/beam_search.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless/beam_search.py @@ -106,7 +106,7 @@ def fast_beam_search( def greedy_search( model: Transducer, encoder_out: torch.Tensor, max_sym_per_frame: int ) -> List[int]: - """ + """Greedy search for a single utterance. Args: model: An instance of `Transducer`. @@ -178,6 +178,64 @@ def greedy_search( return hyp +def greedy_search_batch( + model: Transducer, encoder_out: torch.Tensor +) -> List[List[int]]: + """Greedy search in batch mode. It hardcodes --max-sym-per-frame=1. + Args: + model: + The transducer model. + encoder_out: + Output from the encoder. Its shape is (N, T, C), where N >= 1. + Returns: + Return a list-of-list integers containing the decoded results. + len(ans) equals to encoder_out.size(0). + """ + assert encoder_out.ndim == 3 + assert encoder_out.size(0) >= 1, encoder_out.size(0) + + device = model.device + + batch_size = encoder_out.size(0) + T = encoder_out.size(1) + + blank_id = model.decoder.blank_id + context_size = model.decoder.context_size + + hyps = [[blank_id] * context_size for _ in range(batch_size)] + + decoder_input = torch.tensor( + hyps, + device=device, + dtype=torch.int64, + ) # (batch_size, context_size) + + decoder_out = model.decoder(decoder_input, need_pad=False) + # decoder_out: (batch_size, 1, decoder_out_dim) + for t in range(T): + current_encoder_out = encoder_out[:, t : t + 1, :].unsqueeze(2) # noqa + # current_encoder_out's shape: (batch_size, 1, 1, encoder_out_dim) + logits = model.joiner(current_encoder_out, decoder_out.unsqueeze(1)) + # logits'shape (batch_size, 1, 1, vocab_size) + + logits = logits.squeeze(1).squeeze(1) # (batch_size, vocab_size) + assert logits.ndim == 2, logits.shape + y = logits.argmax(dim=1).tolist() + emitted = False + for i, v in enumerate(y): + if v != blank_id: + hyps[i].append(v) + emitted = True + if emitted: + # update decoder output + decoder_input = [h[-context_size:] for h in hyps] + decoder_input = torch.tensor(decoder_input, device=device) + decoder_out = model.decoder(decoder_input, need_pad=False) + + ans = [h[context_size:] for h in hyps] + return ans + + @dataclass class Hypothesis: # The predicted tokens so far. diff --git a/egs/librispeech/ASR/pruned_transducer_stateless/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless/decode.py index ad76411c0..c43af9741 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless/decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless/decode.py @@ -71,6 +71,7 @@ from beam_search import ( beam_search, fast_beam_search, greedy_search, + greedy_search_batch, modified_beam_search, ) from train import get_params, get_transducer_model @@ -261,6 +262,16 @@ def decode_one_batch( ) for hyp in sp.decode(hyp_tokens): hyps.append(hyp.split()) + elif ( + params.decoding_method == "greedy_search" + and params.max_sym_per_frame == 1 + ): + hyp_tokens = greedy_search_batch( + model=model, + encoder_out=encoder_out, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) else: batch_size = encoder_out.size(0) From 8c7995d493c4309c3d09bdabfa1ab12b4eec2657 Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Tue, 22 Mar 2022 15:14:04 +0800 Subject: [PATCH 03/16] Support modified beam search in batch mode. (#264) * Support modified beam search in batch mode. * Update k2 versions in GitHub CI. --- .../beam_search.py | 145 +++++++++++++++++- .../ASR/pruned_transducer_stateless/decode.py | 14 +- .../pruned_transducer_stateless/pretrained.py | 64 +++++--- requirements-ci.txt | 2 +- 4 files changed, 195 insertions(+), 30 deletions(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless/beam_search.py b/egs/librispeech/ASR/pruned_transducer_stateless/beam_search.py index 05b027214..49004d2ba 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless/beam_search.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless/beam_search.py @@ -188,7 +188,7 @@ def greedy_search_batch( encoder_out: Output from the encoder. Its shape is (N, T, C), where N >= 1. Returns: - Return a list-of-list integers containing the decoded results. + Return a list-of-list of token IDs containing the decoded results. len(ans) equals to encoder_out.size(0). """ assert encoder_out.ndim == 3 @@ -362,13 +362,156 @@ class HypothesisList(object): return ", ".join(s) +def _get_hyps_shape(hyps: List[HypothesisList]) -> k2.RaggedShape: + """Return a ragged shape with axes [utt][num_hyps]. + + Args: + hyps: + len(hyps) == batch_size. It contains the current hypothesis for + each utterance in the batch. + Returns: + Return a ragged shape with 2 axes [utt][num_hyps]. Note that + the shape is on CPU. + """ + num_hyps = [len(h) for h in hyps] + + # torch.cumsum() is inclusive sum, so we put a 0 at the beginning + # to get exclusive sum later. + num_hyps.insert(0, 0) + + num_hyps = torch.tensor(num_hyps) + row_splits = torch.cumsum(num_hyps, dim=0, dtype=torch.int32) + ans = k2.ragged.create_ragged_shape2( + row_splits=row_splits, cached_tot_size=row_splits[-1].item() + ) + return ans + + def modified_beam_search( model: Transducer, encoder_out: torch.Tensor, beam: int = 4, +) -> List[List[int]]: + """Beam search in batch mode with --max-sym-per-frame=1 being hardcoded. + + Args: + model: + The transducer model. + encoder_out: + Output from the encoder. Its shape is (N, T, C). + beam: + Number of active paths during the beam search. + Returns: + Return a list-of-list of token IDs. ans[i] is the decoding results + for the i-th utterance. + """ + assert encoder_out.ndim == 3, encoder_out.shape + + batch_size = encoder_out.size(0) + T = encoder_out.size(1) + + blank_id = model.decoder.blank_id + context_size = model.decoder.context_size + device = model.device + B = [HypothesisList() for _ in range(batch_size)] + for i in range(batch_size): + B[i].add( + Hypothesis( + ys=[blank_id] * context_size, + log_prob=torch.zeros(1, dtype=torch.float32, device=device), + ) + ) + + for t in range(T): + current_encoder_out = encoder_out[:, t : t + 1, :].unsqueeze(2) # noqa + # current_encoder_out's shape is (batch_size, 1, 1, encoder_out_dim) + + hyps_shape = _get_hyps_shape(B).to(device) + + A = [list(b) for b in B] + B = [HypothesisList() for _ in range(batch_size)] + + ys_log_probs = torch.cat( + [hyp.log_prob.reshape(1, 1) for hyps in A for hyp in hyps] + ) # (num_hyps, 1) + + decoder_input = torch.tensor( + [hyp.ys[-context_size:] for hyps in A for hyp in hyps], + device=device, + dtype=torch.int64, + ) # (num_hyps, context_size) + + decoder_out = model.decoder(decoder_input, need_pad=False).unsqueeze(1) + # decoder_output is of shape (num_hyps, 1, 1, decoder_output_dim) + + # Note: For torch 1.7.1 and below, it requires a torch.int64 tensor + # as index, so we use `to(torch.int64)` below. + current_encoder_out = torch.index_select( + current_encoder_out, + dim=0, + index=hyps_shape.row_ids(1).to(torch.int64), + ) # (num_hyps, 1, 1, encoder_out_dim) + + logits = model.joiner( + current_encoder_out, + decoder_out, + ) # (num_hyps, 1, 1, vocab_size) + + logits = logits.squeeze(1).squeeze(1) # (num_hyps, vocab_size) + + log_probs = logits.log_softmax(dim=-1) # (num_hyps, vocab_size) + + log_probs.add_(ys_log_probs) + + vocab_size = log_probs.size(-1) + + log_probs = log_probs.reshape(-1) + + row_splits = hyps_shape.row_splits(1) * vocab_size + log_probs_shape = k2.ragged.create_ragged_shape2( + row_splits=row_splits, cached_tot_size=log_probs.numel() + ) + ragged_log_probs = k2.RaggedTensor( + shape=log_probs_shape, value=log_probs + ) + + for i in range(batch_size): + topk_log_probs, topk_indexes = ragged_log_probs[i].topk(beam) + + topk_hyp_indexes = (topk_indexes // vocab_size).tolist() + topk_token_indexes = (topk_indexes % vocab_size).tolist() + + for k in range(len(topk_hyp_indexes)): + hyp_idx = topk_hyp_indexes[k] + hyp = A[i][hyp_idx] + + new_ys = hyp.ys[:] + new_token = topk_token_indexes[k] + if new_token != blank_id: + new_ys.append(new_token) + + new_log_prob = topk_log_probs[k] + new_hyp = Hypothesis(ys=new_ys, log_prob=new_log_prob) + B[i].add(new_hyp) + + best_hyps = [b.get_most_probable(length_norm=True) for b in B] + ans = [h.ys[context_size:] for h in best_hyps] + + return ans + + +def _deprecated_modified_beam_search( + model: Transducer, + encoder_out: torch.Tensor, + beam: int = 4, ) -> List[int]: """It limits the maximum number of symbols per frame to 1. + It decodes only one utterance at a time. We keep it only for reference. + The function :func:`modified_beam_search` should be preferred as it + supports batch decoding. + + Args: model: An instance of `Transducer`. diff --git a/egs/librispeech/ASR/pruned_transducer_stateless/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless/decode.py index c43af9741..811e74ad7 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless/decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless/decode.py @@ -272,6 +272,14 @@ def decode_one_batch( ) for hyp in sp.decode(hyp_tokens): hyps.append(hyp.split()) + elif params.decoding_method == "modified_beam_search": + hyp_tokens = modified_beam_search( + model=model, + encoder_out=encoder_out, + beam=params.beam_size, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) else: batch_size = encoder_out.size(0) @@ -291,12 +299,6 @@ def decode_one_batch( encoder_out=encoder_out_i, beam=params.beam_size, ) - elif params.decoding_method == "modified_beam_search": - hyp = modified_beam_search( - model=model, - encoder_out=encoder_out_i, - beam=params.beam_size, - ) else: raise ValueError( f"Unsupported decoding method: {params.decoding_method}" diff --git a/egs/librispeech/ASR/pruned_transducer_stateless/pretrained.py b/egs/librispeech/ASR/pruned_transducer_stateless/pretrained.py index e6528b8d7..75b889d7c 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless/pretrained.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless/pretrained.py @@ -50,7 +50,12 @@ import kaldifeat import sentencepiece as spm import torch import torchaudio -from beam_search import beam_search, greedy_search, modified_beam_search +from beam_search import ( + beam_search, + greedy_search, + greedy_search_batch, + modified_beam_search, +) from torch.nn.utils.rnn import pad_sequence from train import get_params, get_transducer_model @@ -224,28 +229,43 @@ def main(): if params.method == "beam_search": msg += f" with beam size {params.beam_size}" logging.info(msg) - for i in range(num_waves): - # fmt: off - encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]] - # fmt: on - if params.method == "greedy_search": - hyp = greedy_search( - model=model, - encoder_out=encoder_out_i, - max_sym_per_frame=params.max_sym_per_frame, - ) - elif params.method == "beam_search": - hyp = beam_search( - model=model, encoder_out=encoder_out_i, beam=params.beam_size - ) - elif params.method == "modified_beam_search": - hyp = modified_beam_search( - model=model, encoder_out=encoder_out_i, beam=params.beam_size - ) - else: - raise ValueError(f"Unsupported method: {params.method}") + if params.method == "modified_beam_search": + hyp_tokens = modified_beam_search( + model=model, + encoder_out=encoder_out, + beam=params.beam_size, + ) - hyps.append(sp.decode(hyp).split()) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.method == "greedy_search" and params.max_sym_per_frame == 1: + hyp_tokens = greedy_search_batch( + model=model, + encoder_out=encoder_out, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + else: + for i in range(num_waves): + # fmt: off + encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]] + # fmt: on + if params.method == "greedy_search": + hyp = greedy_search( + model=model, + encoder_out=encoder_out_i, + max_sym_per_frame=params.max_sym_per_frame, + ) + elif params.method == "beam_search": + hyp = beam_search( + model=model, + encoder_out=encoder_out_i, + beam=params.beam_size, + ) + else: + raise ValueError(f"Unsupported method: {params.method}") + + hyps.append(sp.decode(hyp).split()) s = "\n" for filename, hyp in zip(params.sound_files, hyps): diff --git a/requirements-ci.txt b/requirements-ci.txt index b5ee6b51c..7fb4b1665 100644 --- a/requirements-ci.txt +++ b/requirements-ci.txt @@ -11,7 +11,7 @@ graphviz==0.19.1 -f https://download.pytorch.org/whl/cpu/torch_stable.html torch==1.10.0+cpu -f https://download.pytorch.org/whl/cpu/torch_stable.html torchaudio==0.10.0+cpu --f https://k2-fsa.org/nightly/ k2==1.9.dev20211101+cpu.torch1.10.0 +-f https://k2-fsa.org/nightly/ k2==1.14.dev20220316+cpu.torch1.10.0 git+https://github.com/lhotse-speech/lhotse kaldilm==1.11 From 6a091da0b0543befb0492848d3583700c274d111 Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Wed, 23 Mar 2022 12:22:05 +0800 Subject: [PATCH 04/16] Minor fixes for saving checkpoints. (#265) * Minor fixes for saving checkpoints. * Fix loading checkpoints saved by previous code. --- egs/librispeech/ASR/pruned_transducer_stateless/train.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless/train.py b/egs/librispeech/ASR/pruned_transducer_stateless/train.py index e71f0d1c6..0f51b4382 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless/train.py @@ -395,9 +395,10 @@ def load_checkpoint_if_available( "cur_batch_idx", ] for k in keys: - params[k] = saved_params[k] + params[k] = saved_params.get(k, 0) - params["start_epoch"] = saved_params["cur_epoch"] + if "cur_epoch" in saved_params: + params["start_epoch"] = saved_params["cur_epoch"] return saved_params From 3ae726573713980e78aaf55d170d06e5e747cc10 Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Wed, 23 Mar 2022 14:37:54 +0800 Subject: [PATCH 05/16] More fixes to the checkpoint code. (#266) --- .../ASR/pruned_transducer_stateless/train.py | 22 ++++++++++++++----- 1 file changed, 17 insertions(+), 5 deletions(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless/train.py b/egs/librispeech/ASR/pruned_transducer_stateless/train.py index 0f51b4382..1f52370fd 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless/train.py @@ -392,13 +392,16 @@ def load_checkpoint_if_available( "batch_idx_train", "best_train_loss", "best_valid_loss", - "cur_batch_idx", ] for k in keys: - params[k] = saved_params.get(k, 0) + params[k] = saved_params[k] - if "cur_epoch" in saved_params: - params["start_epoch"] = saved_params["cur_epoch"] + if params.start_batch > 0: + if "cur_epoch" in saved_params: + params["start_epoch"] = saved_params["cur_epoch"] + + if "cur_batch_idx" in saved_params: + params["cur_batch_idx"] = saved_params["cur_batch_idx"] return saved_params @@ -784,6 +787,13 @@ def run(rank, world_size, args): def remove_short_and_long_utt(c: Cut): # Keep only utterances with duration between 1 second and 20 seconds + # + # Caution: There is a reason to select 20.0 here. Please see + # ../local/display_manifest_statistics.py + # + # You should use ../local/display_manifest_statistics.py to get + # an utterance duration distribution for your dataset to select + # the threshold return 1.0 <= c.duration <= 20.0 num_in_total = len(train_cuts) @@ -798,7 +808,9 @@ def run(rank, world_size, args): logging.info(f"After removing short and long utterances: {num_left}") logging.info(f"Removed {num_removed} utterances ({removed_percent:.5f}%)") - if checkpoints and "sampler" in checkpoints: + if params.start_batch > 0 and checkpoints and "sampler" in checkpoints: + # We only load the sampler's state dict when it loads a checkpoint + # saved in the middle of an epoch sampler_state_dict = checkpoints["sampler"] else: sampler_state_dict = None From 395a3f952be1449cd7c92b896f4eb9a1c899e2c7 Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Wed, 23 Mar 2022 19:11:34 +0800 Subject: [PATCH 06/16] Batch decoding for models trained with optimized_transducer (#267) * Add greedy search in batch mode. * Add modified beam search in batch mode. --- .../pretrained.py | 121 +++------ .../pretrained.py | 119 +++------ .../beam_search.py | 6 +- .../ASR/pruned_transducer_stateless/decode.py | 2 +- .../pruned_transducer_stateless/pretrained.py | 2 +- .../ASR/transducer_stateless/beam_search.py | 233 +++++++++++++++++- .../ASR/transducer_stateless/decode.py | 147 ++++------- .../ASR/transducer_stateless/pretrained.py | 145 ++++------- .../decode.py | 146 ++++------- .../pretrained.py | 146 ++++------- 10 files changed, 494 insertions(+), 573 deletions(-) diff --git a/egs/aishell/ASR/transducer_stateless_modified-2/pretrained.py b/egs/aishell/ASR/transducer_stateless_modified-2/pretrained.py index 31bab122c..9e6ed96b1 100755 --- a/egs/aishell/ASR/transducer_stateless_modified-2/pretrained.py +++ b/egs/aishell/ASR/transducer_stateless_modified-2/pretrained.py @@ -55,18 +55,17 @@ from typing import List import kaldifeat import torch -import torch.nn as nn import torchaudio -from beam_search import beam_search, greedy_search, modified_beam_search -from conformer import Conformer -from decoder import Decoder -from joiner import Joiner -from model import Transducer +from beam_search import ( + beam_search, + greedy_search, + greedy_search_batch, + modified_beam_search, +) from torch.nn.utils.rnn import pad_sequence +from train import get_params, get_transducer_model -from icefall.env import get_env_info from icefall.lexicon import Lexicon -from icefall.utils import AttributeDict def get_parser(): @@ -111,6 +110,13 @@ def get_parser(): "The sample rate has to be 16kHz.", ) + parser.add_argument( + "--sample-rate", + type=int, + default=16000, + help="The sample rate of the input sound file", + ) + parser.add_argument( "--beam-size", type=int, @@ -137,70 +143,6 @@ def get_parser(): return parser -def get_params() -> AttributeDict: - params = AttributeDict( - { - # parameters for conformer - "feature_dim": 80, - "encoder_out_dim": 512, - "subsampling_factor": 4, - "attention_dim": 512, - "nhead": 8, - "dim_feedforward": 2048, - "num_encoder_layers": 12, - "vgg_frontend": False, - "env_info": get_env_info(), - "sample_rate": 16000, - } - ) - return params - - -def get_encoder_model(params: AttributeDict) -> nn.Module: - encoder = Conformer( - num_features=params.feature_dim, - output_dim=params.encoder_out_dim, - subsampling_factor=params.subsampling_factor, - d_model=params.attention_dim, - nhead=params.nhead, - dim_feedforward=params.dim_feedforward, - num_encoder_layers=params.num_encoder_layers, - vgg_frontend=params.vgg_frontend, - ) - return encoder - - -def get_decoder_model(params: AttributeDict) -> nn.Module: - decoder = Decoder( - vocab_size=params.vocab_size, - embedding_dim=params.encoder_out_dim, - blank_id=params.blank_id, - context_size=params.context_size, - ) - return decoder - - -def get_joiner_model(params: AttributeDict) -> nn.Module: - joiner = Joiner( - input_dim=params.encoder_out_dim, - output_dim=params.vocab_size, - ) - return joiner - - -def get_transducer_model(params: AttributeDict) -> nn.Module: - encoder = get_encoder_model(params) - decoder = get_decoder_model(params) - joiner = get_joiner_model(params) - - model = Transducer( - encoder=encoder, - decoder=decoder, - joiner=joiner, - ) - return model - - def read_sound_files( filenames: List[str], expected_sample_rate: float ) -> List[torch.Tensor]: @@ -225,6 +167,7 @@ def read_sound_files( return ans +@torch.no_grad() def main(): parser = get_parser() args = parser.parse_args() @@ -249,7 +192,7 @@ def main(): model = get_transducer_model(params) checkpoint = torch.load(args.checkpoint, map_location="cpu") - model.load_state_dict(checkpoint["model"]) + model.load_state_dict(checkpoint["model"], strict=False) model.to(device) model.eval() model.device = device @@ -279,12 +222,22 @@ def main(): features, batch_first=True, padding_value=math.log(1e-10) ) - hyps = [] - with torch.no_grad(): - encoder_out, encoder_out_lens = model.encoder( - x=features, x_lens=feature_lens + encoder_out, encoder_out_lens = model.encoder( + x=features, x_lens=feature_lens + ) + hyp_list = [] + if params.method == "greedy_search" and params.max_sym_per_frame == 1: + hyp_list = greedy_search_batch( + model=model, + encoder_out=encoder_out, ) - + elif params.method == "modified_beam_search": + hyp_list = modified_beam_search( + model=model, + encoder_out=encoder_out, + beam=params.beam_size, + ) + else: for i in range(encoder_out.size(0)): # fmt: off encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]] @@ -301,17 +254,15 @@ def main(): encoder_out=encoder_out_i, beam=params.beam_size, ) - elif params.method == "modified_beam_search": - hyp = modified_beam_search( - model=model, - encoder_out=encoder_out_i, - beam=params.beam_size, - ) else: raise ValueError( f"Unsupported decoding method: {params.method}" ) - hyps.append([lexicon.token_table[i] for i in hyp]) + hyp_list.append(hyp) + + hyps = [] + for hyp in hyp_list: + hyps.append([lexicon.token_table[i] for i in hyp]) s = "\n" for filename, hyp in zip(params.sound_files, hyps): diff --git a/egs/aishell/ASR/transducer_stateless_modified/pretrained.py b/egs/aishell/ASR/transducer_stateless_modified/pretrained.py index 698594e92..f7c5b24ba 100755 --- a/egs/aishell/ASR/transducer_stateless_modified/pretrained.py +++ b/egs/aishell/ASR/transducer_stateless_modified/pretrained.py @@ -55,18 +55,17 @@ from typing import List import kaldifeat import torch -import torch.nn as nn import torchaudio -from beam_search import beam_search, greedy_search, modified_beam_search -from conformer import Conformer -from decoder import Decoder -from joiner import Joiner -from model import Transducer +from beam_search import ( + beam_search, + greedy_search, + greedy_search_batch, + modified_beam_search, +) from torch.nn.utils.rnn import pad_sequence +from train import get_params, get_transducer_model -from icefall.env import get_env_info from icefall.lexicon import Lexicon -from icefall.utils import AttributeDict def get_parser(): @@ -111,6 +110,13 @@ def get_parser(): "The sample rate has to be 16kHz.", ) + parser.add_argument( + "--sample-rate", + type=int, + default=16000, + help="The sample rate of the input sound file", + ) + parser.add_argument( "--beam-size", type=int, @@ -137,70 +143,6 @@ def get_parser(): return parser -def get_params() -> AttributeDict: - params = AttributeDict( - { - # parameters for conformer - "feature_dim": 80, - "encoder_out_dim": 512, - "subsampling_factor": 4, - "attention_dim": 512, - "nhead": 8, - "dim_feedforward": 2048, - "num_encoder_layers": 12, - "vgg_frontend": False, - "env_info": get_env_info(), - "sample_rate": 16000, - } - ) - return params - - -def get_encoder_model(params: AttributeDict) -> nn.Module: - encoder = Conformer( - num_features=params.feature_dim, - output_dim=params.encoder_out_dim, - subsampling_factor=params.subsampling_factor, - d_model=params.attention_dim, - nhead=params.nhead, - dim_feedforward=params.dim_feedforward, - num_encoder_layers=params.num_encoder_layers, - vgg_frontend=params.vgg_frontend, - ) - return encoder - - -def get_decoder_model(params: AttributeDict) -> nn.Module: - decoder = Decoder( - vocab_size=params.vocab_size, - embedding_dim=params.encoder_out_dim, - blank_id=params.blank_id, - context_size=params.context_size, - ) - return decoder - - -def get_joiner_model(params: AttributeDict) -> nn.Module: - joiner = Joiner( - input_dim=params.encoder_out_dim, - output_dim=params.vocab_size, - ) - return joiner - - -def get_transducer_model(params: AttributeDict) -> nn.Module: - encoder = get_encoder_model(params) - decoder = get_decoder_model(params) - joiner = get_joiner_model(params) - - model = Transducer( - encoder=encoder, - decoder=decoder, - joiner=joiner, - ) - return model - - def read_sound_files( filenames: List[str], expected_sample_rate: float ) -> List[torch.Tensor]: @@ -225,6 +167,7 @@ def read_sound_files( return ans +@torch.no_grad() def main(): parser = get_parser() args = parser.parse_args() @@ -279,12 +222,22 @@ def main(): features, batch_first=True, padding_value=math.log(1e-10) ) - hyps = [] - with torch.no_grad(): - encoder_out, encoder_out_lens = model.encoder( - x=features, x_lens=feature_lens + encoder_out, encoder_out_lens = model.encoder( + x=features, x_lens=feature_lens + ) + hyp_list = [] + if params.method == "greedy_search" and params.max_sym_per_frame == 1: + hyp_list = greedy_search_batch( + model=model, + encoder_out=encoder_out, ) - + elif params.method == "modified_beam_search": + hyp_list = modified_beam_search( + model=model, + encoder_out=encoder_out, + beam=params.beam_size, + ) + else: for i in range(encoder_out.size(0)): # fmt: off encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]] @@ -301,17 +254,15 @@ def main(): encoder_out=encoder_out_i, beam=params.beam_size, ) - elif params.method == "modified_beam_search": - hyp = modified_beam_search( - model=model, - encoder_out=encoder_out_i, - beam=params.beam_size, - ) else: raise ValueError( f"Unsupported decoding method: {params.method}" ) - hyps.append([lexicon.token_table[i] for i in hyp]) + hyp_list.append(hyp) + + hyps = [] + for hyp in hyp_list: + hyps.append([lexicon.token_table[i] for i in hyp]) s = "\n" for filename, hyp in zip(params.sound_files, hyps): diff --git a/egs/librispeech/ASR/pruned_transducer_stateless/beam_search.py b/egs/librispeech/ASR/pruned_transducer_stateless/beam_search.py index 49004d2ba..815e1c02a 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless/beam_search.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless/beam_search.py @@ -229,7 +229,11 @@ def greedy_search_batch( if emitted: # update decoder output decoder_input = [h[-context_size:] for h in hyps] - decoder_input = torch.tensor(decoder_input, device=device) + decoder_input = torch.tensor( + decoder_input, + device=device, + dtype=torch.int64, + ) decoder_out = model.decoder(decoder_input, need_pad=False) ans = [h[context_size:] for h in hyps] diff --git a/egs/librispeech/ASR/pruned_transducer_stateless/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless/decode.py index 811e74ad7..8e924bf96 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless/decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless/decode.py @@ -192,7 +192,7 @@ def get_parser(): parser.add_argument( "--max-sym-per-frame", type=int, - default=3, + default=1, help="""Maximum number of symbols per frame. Used only when --decoding_method is greedy_search""", ) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless/pretrained.py b/egs/librispeech/ASR/pruned_transducer_stateless/pretrained.py index 75b889d7c..b0eb4d749 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless/pretrained.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless/pretrained.py @@ -127,7 +127,7 @@ def get_parser(): parser.add_argument( "--max-sym-per-frame", type=int, - default=3, + default=1, help="""Maximum number of symbols per frame. Used only when --method is greedy_search. """, diff --git a/egs/librispeech/ASR/transducer_stateless/beam_search.py b/egs/librispeech/ASR/transducer_stateless/beam_search.py index c5efb733d..7b4fac31d 100644 --- a/egs/librispeech/ASR/transducer_stateless/beam_search.py +++ b/egs/librispeech/ASR/transducer_stateless/beam_search.py @@ -17,6 +17,7 @@ from dataclasses import dataclass from typing import Dict, List, Optional +import k2 import torch from model import Transducer @@ -24,7 +25,7 @@ from model import Transducer def greedy_search( model: Transducer, encoder_out: torch.Tensor, max_sym_per_frame: int ) -> List[int]: - """ + """Greedy search for a single utterance. Args: model: An instance of `Transducer`. @@ -80,7 +81,7 @@ def greedy_search( logits = model.joiner( current_encoder_out, decoder_out, encoder_out_len, decoder_out_len ) - # logits is (1, 1, 1, vocab_size) + # logits is (1, vocab_size) y = logits.argmax().item() if y != blank_id: @@ -101,6 +102,75 @@ def greedy_search( return hyp +def greedy_search_batch( + model: Transducer, encoder_out: torch.Tensor +) -> List[List[int]]: + """Greedy search in batch mode. It hardcodes --max-sym-per-frame=1. + Args: + model: + The transducer model. + encoder_out: + Output from the encoder. Its shape is (N, T, C), where N >= 1. + Returns: + Return a list-of-list of token IDs containing the decoded results. + len(ans) equals to encoder_out.size(0). + """ + assert encoder_out.ndim == 3 + assert encoder_out.size(0) >= 1, encoder_out.size(0) + + device = model.device + + batch_size = encoder_out.size(0) + T = encoder_out.size(1) + + blank_id = model.decoder.blank_id + context_size = model.decoder.context_size + + hyps = [[blank_id] * context_size for _ in range(batch_size)] + + decoder_input = torch.tensor( + hyps, + device=device, + dtype=torch.int64, + ) # (batch_size, context_size) + decoder_out = model.decoder(decoder_input, need_pad=False) + # decoder_out: (batch_size, 1, decoder_out_dim) + + encoder_out_len = torch.ones(batch_size, dtype=torch.int32) + decoder_out_len = torch.ones(batch_size, dtype=torch.int32) + + for t in range(T): + current_encoder_out = encoder_out[:, t : t + 1, :] # noqa + # current_encoder_out's shape: (batch_size, 1, encoder_out_dim) + logits = model.joiner( + current_encoder_out, decoder_out, encoder_out_len, decoder_out_len + ) # (batch_size, vocab_size) + + assert logits.ndim == 2, logits.shape + y = logits.argmax(dim=1).tolist() + emitted = False + for i, v in enumerate(y): + if v != blank_id: + hyps[i].append(v) + emitted = True + + if emitted: + # update decoder output + decoder_input = [h[-context_size:] for h in hyps] + decoder_input = torch.tensor( + decoder_input, + device=device, + dtype=torch.int64, + ) # (batch_size, context_size) + decoder_out = model.decoder( + decoder_input, + need_pad=False, + ) # (batch_size, 1, decoder_out_dim) + + ans = [h[context_size:] for h in hyps] + return ans + + @dataclass class Hypothesis: # The predicted tokens so far. @@ -252,9 +322,11 @@ def run_decoder( device = model.device - decoder_input = torch.tensor([ys[-context_size:]], device=device).reshape( - 1, context_size - ) + decoder_input = torch.tensor( + [ys[-context_size:]], + device=device, + dtype=torch.int64, + ).reshape(1, context_size) decoder_out = model.decoder(decoder_input, need_pad=False) decoder_cache[key] = decoder_out @@ -314,13 +386,158 @@ def run_joiner( return log_prob +def _get_hyps_shape(hyps: List[HypothesisList]) -> k2.RaggedShape: + """Return a ragged shape with axes [utt][num_hyps]. + + Args: + hyps: + len(hyps) == batch_size. It contains the current hypothesis for + each utterance in the batch. + Returns: + Return a ragged shape with 2 axes [utt][num_hyps]. Note that + the shape is on CPU. + """ + num_hyps = [len(h) for h in hyps] + + # torch.cumsum() is inclusive sum, so we put a 0 at the beginning + # to get exclusive sum later. + num_hyps.insert(0, 0) + + num_hyps = torch.tensor(num_hyps) + row_splits = torch.cumsum(num_hyps, dim=0, dtype=torch.int32) + ans = k2.ragged.create_ragged_shape2( + row_splits=row_splits, cached_tot_size=row_splits[-1].item() + ) + return ans + + def modified_beam_search( model: Transducer, encoder_out: torch.Tensor, beam: int = 4, +) -> List[List[int]]: + """Beam search in batch mode with --max-sym-per-frame=1 being hardcodded. + + Args: + model: + The transducer model. + encoder_out: + Output from the encoder. Its shape is (N, T, C). + beam: + Number of active paths during the beam search. + Returns: + Return a list-of-list of token IDs. ans[i] is the decoding results + for the i-th utterance. + """ + assert encoder_out.ndim == 3, encoder_out.shape + + batch_size = encoder_out.size(0) + T = encoder_out.size(1) + + blank_id = model.decoder.blank_id + context_size = model.decoder.context_size + device = model.device + B = [HypothesisList() for _ in range(batch_size)] + for i in range(batch_size): + B[i].add( + Hypothesis( + ys=[blank_id] * context_size, + log_prob=torch.zeros(1, dtype=torch.float32, device=device), + ) + ) + + encoder_out_len = torch.tensor([1]) + decoder_out_len = torch.tensor([1]) + for t in range(T): + current_encoder_out = encoder_out[:, t : t + 1, :] # noqa + # current_encoder_out's shape is: (batch_size, 1, encoder_out_dim) + + hyps_shape = _get_hyps_shape(B).to(device) + + A = [list(b) for b in B] + B = [HypothesisList() for _ in range(batch_size)] + + ys_log_probs = torch.cat( + [hyp.log_prob.reshape(1, 1) for hyps in A for hyp in hyps] + ) # (num_hyps, 1) + + decoder_input = torch.tensor( + [hyp.ys[-context_size:] for hyps in A for hyp in hyps], + device=device, + dtype=torch.int64, + ) # (num_hyps, context_size) + + decoder_out = model.decoder(decoder_input, need_pad=False) + # decoder_output is of shape (num_hyps, 1, decoder_output_dim) + + # Note: For torch 1.7.1 and below, it requires a torch.int64 tensor + # as index, so we use `to(torch.int64)` below. + current_encoder_out = torch.index_select( + current_encoder_out, + dim=0, + index=hyps_shape.row_ids(1).to(torch.int64), + ) # (num_hyps, 1, encoder_out_dim) + + logits = model.joiner( + current_encoder_out, + decoder_out, + encoder_out_len.expand(decoder_out.size(0)), + decoder_out_len.expand(decoder_out.size(0)), + ) + # logits is of shape (num_hyps, vocab_size) + + log_probs = logits.log_softmax(dim=-1) # (num_hyps, vocab_size) + + log_probs.add_(ys_log_probs) + + vocab_size = log_probs.size(-1) + + log_probs = log_probs.reshape(-1) + + row_splits = hyps_shape.row_splits(1) * vocab_size + log_probs_shape = k2.ragged.create_ragged_shape2( + row_splits=row_splits, cached_tot_size=log_probs.numel() + ) + ragged_log_probs = k2.RaggedTensor( + shape=log_probs_shape, value=log_probs + ) + + for i in range(batch_size): + topk_log_probs, topk_indexes = ragged_log_probs[i].topk(beam) + + topk_hyp_indexes = (topk_indexes // vocab_size).tolist() + topk_token_indexes = (topk_indexes % vocab_size).tolist() + + for k in range(len(topk_hyp_indexes)): + hyp_idx = topk_hyp_indexes[k] + hyp = A[i][hyp_idx] + + new_ys = hyp.ys[:] + new_token = topk_token_indexes[k] + if new_token != blank_id: + new_ys.append(new_token) + + new_log_prob = topk_log_probs[k] + new_hyp = Hypothesis(ys=new_ys, log_prob=new_log_prob) + B[i].add(new_hyp) + + best_hyps = [b.get_most_probable(length_norm=True) for b in B] + ans = [h.ys[context_size:] for h in best_hyps] + + return ans + + +def _deprecated_modified_beam_search( + model: Transducer, + encoder_out: torch.Tensor, + beam: int = 4, ) -> List[int]: """It limits the maximum number of symbols per frame to 1. + It decodes only one utterance at a time. We keep it only for reference. + The function :func:`modified_beam_search` should be preferred as it + supports batch decoding. + Args: model: An instance of `Transducer`. @@ -341,12 +558,6 @@ def modified_beam_search( device = model.device - decoder_input = torch.tensor( - [blank_id] * context_size, device=device - ).reshape(1, context_size) - - decoder_out = model.decoder(decoder_input, need_pad=False) - T = encoder_out.size(1) B = HypothesisList() diff --git a/egs/librispeech/ASR/transducer_stateless/decode.py b/egs/librispeech/ASR/transducer_stateless/decode.py index f23a3a300..ac66c9b49 100755 --- a/egs/librispeech/ASR/transducer_stateless/decode.py +++ b/egs/librispeech/ASR/transducer_stateless/decode.py @@ -55,14 +55,15 @@ import sentencepiece as spm import torch import torch.nn as nn from asr_datamodule import LibriSpeechAsrDataModule -from beam_search import beam_search, greedy_search, modified_beam_search -from conformer import Conformer -from decoder import Decoder -from joiner import Joiner -from model import Transducer +from beam_search import ( + beam_search, + greedy_search, + greedy_search_batch, + modified_beam_search, +) +from train import get_params, get_transducer_model from icefall.checkpoint import average_checkpoints, load_checkpoint -from icefall.env import get_env_info from icefall.utils import ( AttributeDict, setup_logger, @@ -135,7 +136,7 @@ def get_parser(): parser.add_argument( "--max-sym-per-frame", type=int, - default=3, + default=1, help="""Maximum number of symbols per frame. Used only when --decoding_method is greedy_search""", ) @@ -143,70 +144,6 @@ def get_parser(): return parser -def get_params() -> AttributeDict: - params = AttributeDict( - { - # parameters for conformer - "feature_dim": 80, - "encoder_out_dim": 512, - "subsampling_factor": 4, - "attention_dim": 512, - "nhead": 8, - "dim_feedforward": 2048, - "num_encoder_layers": 12, - "vgg_frontend": False, - "env_info": get_env_info(), - } - ) - return params - - -def get_encoder_model(params: AttributeDict): - # TODO: We can add an option to switch between Conformer and Transformer - encoder = Conformer( - num_features=params.feature_dim, - output_dim=params.encoder_out_dim, - subsampling_factor=params.subsampling_factor, - d_model=params.attention_dim, - nhead=params.nhead, - dim_feedforward=params.dim_feedforward, - num_encoder_layers=params.num_encoder_layers, - vgg_frontend=params.vgg_frontend, - ) - return encoder - - -def get_decoder_model(params: AttributeDict): - decoder = Decoder( - vocab_size=params.vocab_size, - embedding_dim=params.encoder_out_dim, - blank_id=params.blank_id, - context_size=params.context_size, - ) - return decoder - - -def get_joiner_model(params: AttributeDict): - joiner = Joiner( - input_dim=params.encoder_out_dim, - output_dim=params.vocab_size, - ) - return joiner - - -def get_transducer_model(params: AttributeDict): - encoder = get_encoder_model(params) - decoder = get_decoder_model(params) - joiner = get_joiner_model(params) - - model = Transducer( - encoder=encoder, - decoder=decoder, - joiner=joiner, - ) - return model - - def decode_one_batch( params: AttributeDict, model: nn.Module, @@ -251,32 +188,47 @@ def decode_one_batch( encoder_out, encoder_out_lens = model.encoder( x=feature, x_lens=feature_lens ) - hyps = [] - batch_size = encoder_out.size(0) + hyp_list: List[List[int]] = [] - for i in range(batch_size): - # fmt: off - encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]] - # fmt: on - if params.decoding_method == "greedy_search": - hyp = greedy_search( - model=model, - encoder_out=encoder_out_i, - max_sym_per_frame=params.max_sym_per_frame, - ) - elif params.decoding_method == "beam_search": - hyp = beam_search( - model=model, encoder_out=encoder_out_i, beam=params.beam_size - ) - elif params.decoding_method == "modified_beam_search": - hyp = modified_beam_search( - model=model, encoder_out=encoder_out_i, beam=params.beam_size - ) - else: - raise ValueError( - f"Unsupported decoding method: {params.decoding_method}" - ) - hyps.append(sp.decode(hyp).split()) + if ( + params.decoding_method == "greedy_search" + and params.max_sym_per_frame == 1 + ): + hyp_list = greedy_search_batch( + model=model, + encoder_out=encoder_out, + ) + elif params.decoding_method == "modified_beam_search": + hyp_list = modified_beam_search( + model=model, + encoder_out=encoder_out, + beam=params.beam_size, + ) + else: + batch_size = encoder_out.size(0) + for i in range(batch_size): + # fmt: off + encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]] + # fmt: on + if params.decoding_method == "greedy_search": + hyp = greedy_search( + model=model, + encoder_out=encoder_out_i, + max_sym_per_frame=params.max_sym_per_frame, + ) + elif params.decoding_method == "beam_search": + hyp = beam_search( + model=model, + encoder_out=encoder_out_i, + beam=params.beam_size, + ) + else: + raise ValueError( + f"Unsupported decoding method: {params.decoding_method}" + ) + hyp_list.append(hyp) + + hyps = [sp.decode(hyp).split() for hyp in hyp_list] if params.decoding_method == "greedy_search": return {"greedy_search": hyps} @@ -487,8 +439,5 @@ def main(): logging.info("Done!") -torch.set_num_threads(1) -torch.set_num_interop_threads(1) - if __name__ == "__main__": main() diff --git a/egs/librispeech/ASR/transducer_stateless/pretrained.py b/egs/librispeech/ASR/transducer_stateless/pretrained.py index ad8d89918..4fb5d92c5 100755 --- a/egs/librispeech/ASR/transducer_stateless/pretrained.py +++ b/egs/librispeech/ASR/transducer_stateless/pretrained.py @@ -59,17 +59,15 @@ from typing import List import kaldifeat import sentencepiece as spm import torch -import torch.nn as nn import torchaudio -from beam_search import beam_search, greedy_search, modified_beam_search -from conformer import Conformer -from decoder import Decoder -from joiner import Joiner -from model import Transducer +from beam_search import ( + beam_search, + greedy_search, + greedy_search_batch, + modified_beam_search, +) from torch.nn.utils.rnn import pad_sequence - -from icefall.env import get_env_info -from icefall.utils import AttributeDict +from train import get_params, get_transducer_model def get_parser(): @@ -115,6 +113,13 @@ def get_parser(): "The sample rate has to be 16kHz.", ) + parser.add_argument( + "--sample-rate", + type=int, + default=16000, + help="The sample rate of the input sound file", + ) + parser.add_argument( "--beam-size", type=int, @@ -132,7 +137,7 @@ def get_parser(): parser.add_argument( "--max-sym-per-frame", type=int, - default=3, + default=1, help="""Maximum number of symbols per frame. Used only when --method is greedy_search. """, @@ -141,70 +146,6 @@ def get_parser(): return parser -def get_params() -> AttributeDict: - params = AttributeDict( - { - "sample_rate": 16000, - # parameters for conformer - "feature_dim": 80, - "encoder_out_dim": 512, - "subsampling_factor": 4, - "attention_dim": 512, - "nhead": 8, - "dim_feedforward": 2048, - "num_encoder_layers": 12, - "vgg_frontend": False, - "env_info": get_env_info(), - } - ) - return params - - -def get_encoder_model(params: AttributeDict) -> nn.Module: - encoder = Conformer( - num_features=params.feature_dim, - output_dim=params.encoder_out_dim, - subsampling_factor=params.subsampling_factor, - d_model=params.attention_dim, - nhead=params.nhead, - dim_feedforward=params.dim_feedforward, - num_encoder_layers=params.num_encoder_layers, - vgg_frontend=params.vgg_frontend, - ) - return encoder - - -def get_decoder_model(params: AttributeDict) -> nn.Module: - decoder = Decoder( - vocab_size=params.vocab_size, - embedding_dim=params.encoder_out_dim, - blank_id=params.blank_id, - context_size=params.context_size, - ) - return decoder - - -def get_joiner_model(params: AttributeDict) -> nn.Module: - joiner = Joiner( - input_dim=params.encoder_out_dim, - output_dim=params.vocab_size, - ) - return joiner - - -def get_transducer_model(params: AttributeDict) -> nn.Module: - encoder = get_encoder_model(params) - decoder = get_decoder_model(params) - joiner = get_joiner_model(params) - - model = Transducer( - encoder=encoder, - decoder=decoder, - joiner=joiner, - ) - return model - - def read_sound_files( filenames: List[str], expected_sample_rate: float ) -> List[torch.Tensor]: @@ -294,33 +235,45 @@ def main(): ) num_waves = encoder_out.size(0) - hyps = [] + hyp_list = [] msg = f"Using {params.method}" if params.method == "beam_search": msg += f" with beam size {params.beam_size}" logging.info(msg) - for i in range(num_waves): - # fmt: off - encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]] - # fmt: on - if params.method == "greedy_search": - hyp = greedy_search( - model=model, - encoder_out=encoder_out_i, - max_sym_per_frame=params.max_sym_per_frame, - ) - elif params.method == "beam_search": - hyp = beam_search( - model=model, encoder_out=encoder_out_i, beam=params.beam_size - ) - elif params.method == "modified_beam_search": - hyp = modified_beam_search( - model=model, encoder_out=encoder_out_i, beam=params.beam_size - ) - else: - raise ValueError(f"Unsupported method: {params.method}") - hyps.append(sp.decode(hyp).split()) + if params.method == "greedy_search" and params.max_sym_per_frame == 1: + hyp_list = greedy_search_batch( + model=model, + encoder_out=encoder_out, + ) + elif params.method == "modified_beam_search": + hyp_list = modified_beam_search( + model=model, + encoder_out=encoder_out, + beam=params.beam_size, + ) + else: + for i in range(num_waves): + # fmt: off + encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]] + # fmt: on + if params.method == "greedy_search": + hyp = greedy_search( + model=model, + encoder_out=encoder_out_i, + max_sym_per_frame=params.max_sym_per_frame, + ) + elif params.method == "beam_search": + hyp = beam_search( + model=model, + encoder_out=encoder_out_i, + beam=params.beam_size, + ) + else: + raise ValueError(f"Unsupported method: {params.method}") + hyp_list.append(hyp) + + hyps = [sp.decode(hyp).split() for hyp in hyp_list] s = "\n" for filename, hyp in zip(params.sound_files, hyps): diff --git a/egs/librispeech/ASR/transducer_stateless_multi_datasets/decode.py b/egs/librispeech/ASR/transducer_stateless_multi_datasets/decode.py index 136afe9c0..22f137d36 100755 --- a/egs/librispeech/ASR/transducer_stateless_multi_datasets/decode.py +++ b/egs/librispeech/ASR/transducer_stateless_multi_datasets/decode.py @@ -46,15 +46,16 @@ import sentencepiece as spm import torch import torch.nn as nn from asr_datamodule import AsrDataModule -from beam_search import beam_search, greedy_search, modified_beam_search -from conformer import Conformer -from decoder import Decoder -from joiner import Joiner +from beam_search import ( + beam_search, + greedy_search, + greedy_search_batch, + modified_beam_search, +) from librispeech import LibriSpeech -from model import Transducer +from train import get_params, get_transducer_model from icefall.checkpoint import average_checkpoints, load_checkpoint -from icefall.env import get_env_info from icefall.utils import ( AttributeDict, setup_logger, @@ -127,7 +128,7 @@ def get_parser(): parser.add_argument( "--max-sym-per-frame", type=int, - default=3, + default=1, help="""Maximum number of symbols per frame. Used only when --decoding_method is greedy_search""", ) @@ -135,71 +136,6 @@ def get_parser(): return parser -def get_params() -> AttributeDict: - params = AttributeDict( - { - # parameters for conformer - "feature_dim": 80, - "encoder_out_dim": 512, - "subsampling_factor": 4, - "attention_dim": 512, - "nhead": 8, - "dim_feedforward": 2048, - "num_encoder_layers": 12, - "vgg_frontend": False, - "env_info": get_env_info(), - } - ) - return params - - -def get_encoder_model(params: AttributeDict): - # TODO: We can add an option to switch between Conformer and Transformer - encoder = Conformer( - num_features=params.feature_dim, - output_dim=params.encoder_out_dim, - subsampling_factor=params.subsampling_factor, - d_model=params.attention_dim, - nhead=params.nhead, - dim_feedforward=params.dim_feedforward, - num_encoder_layers=params.num_encoder_layers, - vgg_frontend=params.vgg_frontend, - ) - return encoder - - -def get_decoder_model(params: AttributeDict): - decoder = Decoder( - vocab_size=params.vocab_size, - embedding_dim=params.encoder_out_dim, - blank_id=params.blank_id, - context_size=params.context_size, - ) - return decoder - - -def get_joiner_model(params: AttributeDict): - joiner = Joiner( - input_dim=params.encoder_out_dim, - output_dim=params.vocab_size, - ) - return joiner - - -def get_transducer_model(params: AttributeDict): - encoder = get_encoder_model(params) - decoder = get_decoder_model(params) - joiner = get_joiner_model(params) - - model = Transducer( - encoder=encoder, - decoder=decoder, - joiner=joiner, - ) - - return model - - def decode_one_batch( params: AttributeDict, model: nn.Module, @@ -244,32 +180,47 @@ def decode_one_batch( encoder_out, encoder_out_lens = model.encoder( x=feature, x_lens=feature_lens ) - hyps = [] + hyp_list = [] batch_size = encoder_out.size(0) - for i in range(batch_size): - # fmt: off - encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]] - # fmt: on - if params.decoding_method == "greedy_search": - hyp = greedy_search( - model=model, - encoder_out=encoder_out_i, - max_sym_per_frame=params.max_sym_per_frame, - ) - elif params.decoding_method == "beam_search": - hyp = beam_search( - model=model, encoder_out=encoder_out_i, beam=params.beam_size - ) - elif params.decoding_method == "modified_beam_search": - hyp = modified_beam_search( - model=model, encoder_out=encoder_out_i, beam=params.beam_size - ) - else: - raise ValueError( - f"Unsupported decoding method: {params.decoding_method}" - ) - hyps.append(sp.decode(hyp).split()) + if ( + params.decoding_method == "greedy_search" + and params.max_sym_per_frame == 1 + ): + hyp_list = greedy_search_batch( + model=model, + encoder_out=encoder_out, + ) + elif params.decoding_method == "modified_beam_search": + hyp_list = modified_beam_search( + model=model, + encoder_out=encoder_out, + beam=params.beam_size, + ) + else: + for i in range(batch_size): + # fmt: off + encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]] + # fmt: on + if params.decoding_method == "greedy_search": + hyp = greedy_search( + model=model, + encoder_out=encoder_out_i, + max_sym_per_frame=params.max_sym_per_frame, + ) + elif params.decoding_method == "beam_search": + hyp = beam_search( + model=model, + encoder_out=encoder_out_i, + beam=params.beam_size, + ) + else: + raise ValueError( + f"Unsupported decoding method: {params.decoding_method}" + ) + hyp_list.append(sp.decode(hyp).split()) + + hyps = [sp.decode(hyp).split() for hyp in hyp_list] if params.decoding_method == "greedy_search": return {"greedy_search": hyps} @@ -483,8 +434,5 @@ def main(): logging.info("Done!") -torch.set_num_threads(1) -torch.set_num_interop_threads(1) - if __name__ == "__main__": main() diff --git a/egs/librispeech/ASR/transducer_stateless_multi_datasets/pretrained.py b/egs/librispeech/ASR/transducer_stateless_multi_datasets/pretrained.py index 5ba3acea1..df9c3186f 100755 --- a/egs/librispeech/ASR/transducer_stateless_multi_datasets/pretrained.py +++ b/egs/librispeech/ASR/transducer_stateless_multi_datasets/pretrained.py @@ -59,17 +59,15 @@ from typing import List import kaldifeat import sentencepiece as spm import torch -import torch.nn as nn import torchaudio -from beam_search import beam_search, greedy_search, modified_beam_search -from conformer import Conformer -from decoder import Decoder -from joiner import Joiner -from model import Transducer +from beam_search import ( + beam_search, + greedy_search, + greedy_search_batch, + modified_beam_search, +) from torch.nn.utils.rnn import pad_sequence - -from icefall.env import get_env_info -from icefall.utils import AttributeDict +from train import get_params, get_transducer_model def get_parser(): @@ -115,6 +113,13 @@ def get_parser(): "The sample rate has to be 16kHz.", ) + parser.add_argument( + "--sample-rate", + type=int, + default=16000, + help="The sample rate of the input sound file", + ) + parser.add_argument( "--beam-size", type=int, @@ -132,7 +137,7 @@ def get_parser(): parser.add_argument( "--max-sym-per-frame", type=int, - default=3, + default=1, help="""Maximum number of symbols per frame. Used only when --method is greedy_search. """, @@ -141,70 +146,6 @@ def get_parser(): return parser -def get_params() -> AttributeDict: - params = AttributeDict( - { - "sample_rate": 16000, - # parameters for conformer - "feature_dim": 80, - "encoder_out_dim": 512, - "subsampling_factor": 4, - "attention_dim": 512, - "nhead": 8, - "dim_feedforward": 2048, - "num_encoder_layers": 12, - "vgg_frontend": False, - "env_info": get_env_info(), - } - ) - return params - - -def get_encoder_model(params: AttributeDict) -> nn.Module: - encoder = Conformer( - num_features=params.feature_dim, - output_dim=params.encoder_out_dim, - subsampling_factor=params.subsampling_factor, - d_model=params.attention_dim, - nhead=params.nhead, - dim_feedforward=params.dim_feedforward, - num_encoder_layers=params.num_encoder_layers, - vgg_frontend=params.vgg_frontend, - ) - return encoder - - -def get_decoder_model(params: AttributeDict) -> nn.Module: - decoder = Decoder( - vocab_size=params.vocab_size, - embedding_dim=params.encoder_out_dim, - blank_id=params.blank_id, - context_size=params.context_size, - ) - return decoder - - -def get_joiner_model(params: AttributeDict) -> nn.Module: - joiner = Joiner( - input_dim=params.encoder_out_dim, - output_dim=params.vocab_size, - ) - return joiner - - -def get_transducer_model(params: AttributeDict) -> nn.Module: - encoder = get_encoder_model(params) - decoder = get_decoder_model(params) - joiner = get_joiner_model(params) - - model = Transducer( - encoder=encoder, - decoder=decoder, - joiner=joiner, - ) - return model - - def read_sound_files( filenames: List[str], expected_sample_rate: float ) -> List[torch.Tensor]: @@ -294,33 +235,46 @@ def main(): ) num_waves = encoder_out.size(0) - hyps = [] + hyp_list = [] msg = f"Using {params.method}" if params.method == "beam_search": msg += f" with beam size {params.beam_size}" logging.info(msg) - for i in range(num_waves): - # fmt: off - encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]] - # fmt: on - if params.method == "greedy_search": - hyp = greedy_search( - model=model, - encoder_out=encoder_out_i, - max_sym_per_frame=params.max_sym_per_frame, - ) - elif params.method == "beam_search": - hyp = beam_search( - model=model, encoder_out=encoder_out_i, beam=params.beam_size - ) - elif params.method == "modified_beam_search": - hyp = modified_beam_search( - model=model, encoder_out=encoder_out_i, beam=params.beam_size - ) - else: - raise ValueError(f"Unsupported method: {params.method}") - hyps.append(sp.decode(hyp).split()) + if params.method == "greedy_search" and params.max_sym_per_frame == 1: + hyp_list = greedy_search_batch( + model=model, + encoder_out=encoder_out, + ) + elif params.method == "modified_beam_search": + hyp_list = modified_beam_search( + model=model, + encoder_out=encoder_out, + beam=params.beam_size, + ) + + else: + for i in range(num_waves): + # fmt: off + encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]] + # fmt: on + if params.method == "greedy_search": + hyp = greedy_search( + model=model, + encoder_out=encoder_out_i, + max_sym_per_frame=params.max_sym_per_frame, + ) + elif params.method == "beam_search": + hyp = beam_search( + model=model, + encoder_out=encoder_out_i, + beam=params.beam_size, + ) + else: + raise ValueError(f"Unsupported method: {params.method}") + hyp_list.append(hyp) + + hyps = [sp.decode(hyp).split() for hyp in hyp_list] s = "\n" for filename, hyp in zip(params.sound_files, hyps): From f686635b546baa00654f9e3caed739adf04c399e Mon Sep 17 00:00:00 2001 From: Mingshuang Luo <37799481+luomingshuang@users.noreply.github.com> Date: Wed, 30 Mar 2022 14:52:55 +0800 Subject: [PATCH 07/16] Update diagnostics (#260) * update diagnostics.py --- icefall/diagnostics.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/icefall/diagnostics.py b/icefall/diagnostics.py index fa9b98fa0..08d1628ec 100644 --- a/icefall/diagnostics.py +++ b/icefall/diagnostics.py @@ -135,8 +135,13 @@ def get_diagnostics_for_dim( return "" count = sum(counts) stats = stats / count - stats, _ = torch.symeig(stats) - stats = stats.abs().sqrt() + try: + eigs, _ = torch.symeig(stats) + stats = eigs.abs().sqrt() + except: # noqa + print("Error getting eigenvalues, trying another method.") + eigs, _ = torch.eigs(stats) + stats = eigs.abs().sqrt() # sqrt so it reflects data magnitude, like stddev- not variance elif sizes_same: stats = torch.stack(stats).sum(dim=0) From 981b0640079918a43826b82acdadde68e2517bc9 Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Wed, 30 Mar 2022 18:50:54 +0800 Subject: [PATCH 08/16] Update doc to clarify the installation order of dependencies. (#279) --- docs/source/installation/index.rst | 29 +++++++++++++++++++++-------- 1 file changed, 21 insertions(+), 8 deletions(-) diff --git a/docs/source/installation/index.rst b/docs/source/installation/index.rst index a8c3b6865..5d364dbc0 100644 --- a/docs/source/installation/index.rst +++ b/docs/source/installation/index.rst @@ -27,9 +27,21 @@ Installation ``icefall`` depends on `k2 `_ and `lhotse `_. -We recommend you to install ``k2`` first, as ``k2`` is bound to -a specific version of PyTorch after compilation. Install ``k2`` also -installs its dependency PyTorch, which can be reused by ``lhotse``. +We recommend you to use the following steps to install the dependencies. + +- (0) Install PyTorch and torchaudio +- (1) Install k2 +- (2) Install lhotse + +.. caution:: + + Installation order matters. + +(0) Install PyTorch and torchaudio +---------------------------------- + +Please refer ``_ to install PyTorch +and torchaudio. (1) Install k2 @@ -54,14 +66,15 @@ to install ``k2``. Please refer to ``_ to install ``lhotse``. -.. HINT:: - Install ``lhotse`` also installs its dependency `torchaudio `_. +.. hint:: -.. CAUTION:: + We strongly recommend you to use:: + + pip install git+https://github.com/lhotse-speech/lhotse + + to install the latest version of lhotse. - If you have installed ``torchaudio``, please consider uninstalling it before - installing ``lhotse``. Otherwise, it may update your already installed PyTorch. (3) Download icefall -------------------- From 2045125fd96a8c0c925f6824d90512e43ac01fb5 Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Thu, 31 Mar 2022 10:43:02 +0800 Subject: [PATCH 09/16] Fix CI. (#280) * Fix CI. --- .github/workflows/style_check.yml | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/.github/workflows/style_check.yml b/.github/workflows/style_check.yml index 2a743705a..6b3d856df 100644 --- a/.github/workflows/style_check.yml +++ b/.github/workflows/style_check.yml @@ -45,7 +45,9 @@ jobs: - name: Install Python dependencies run: | - python3 -m pip install --upgrade pip black==21.6b0 flake8==3.9.2 + python3 -m pip install --upgrade pip black==21.6b0 flake8==3.9.2 click==8.0.4 + # See https://github.com/psf/black/issues/2964 + # The version of click should be selected from 8.0.0, 8.0.1, 8.0.2, 8.0.3, and 8.0.4 - name: Run flake8 shell: bash From fc40bfea8222400ffdcb437d0d4708053a619cb2 Mon Sep 17 00:00:00 2001 From: "LIyong.Guo" <839019390@qq.com> Date: Thu, 31 Mar 2022 10:43:46 +0800 Subject: [PATCH 10/16] fix typo of torch.eig (#281) Co-authored-by: glynpu --- icefall/diagnostics.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/icefall/diagnostics.py b/icefall/diagnostics.py index 08d1628ec..ce4ac1464 100644 --- a/icefall/diagnostics.py +++ b/icefall/diagnostics.py @@ -140,7 +140,7 @@ def get_diagnostics_for_dim( stats = eigs.abs().sqrt() except: # noqa print("Error getting eigenvalues, trying another method.") - eigs, _ = torch.eigs(stats) + eigs = torch.linalg.eigvals(stats) stats = eigs.abs().sqrt() # sqrt so it reflects data magnitude, like stddev- not variance elif sizes_same: From 9a11808ed36b57cb17cfd328f1a8537f86f468a5 Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Thu, 31 Mar 2022 16:48:46 +0800 Subject: [PATCH 11/16] Set the seed for dataloader. (#282) Also, suppress torch warnings about division by truncation. --- .../ASR/pruned_transducer_stateless/train.py | 7 ++++++- egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py | 10 ++++++++++ egs/librispeech/ASR/transducer/train.py | 7 ++++++- egs/librispeech/ASR/transducer_lstm/train.py | 7 ++++++- egs/librispeech/ASR/transducer_stateless/conformer.py | 7 +++++-- egs/librispeech/ASR/transducer_stateless/train.py | 7 ++++++- .../asr_datamodule.py | 11 +++++++++++ .../ASR/transducer_stateless_multi_datasets/train.py | 7 ++++++- 8 files changed, 56 insertions(+), 7 deletions(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless/train.py b/egs/librispeech/ASR/pruned_transducer_stateless/train.py index 1f52370fd..17f82e601 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless/train.py @@ -33,6 +33,7 @@ export CUDA_VISIBLE_DEVICES="0,1,2,3" import argparse import logging +import warnings from pathlib import Path from shutil import copyfile from typing import Any, Dict, Optional, Tuple @@ -496,7 +497,11 @@ def compute_loss( assert loss.requires_grad == is_training info = MetricsTracker() - info["frames"] = (feature_lens // params.subsampling_factor).sum().item() + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + info["frames"] = ( + (feature_lens // params.subsampling_factor).sum().item() + ) # Note: We use reduction=sum while computing the loss. info["loss"] = loss.detach().cpu().item() diff --git a/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py b/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py index a460c8eb8..8790b21e7 100644 --- a/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py +++ b/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py @@ -23,6 +23,7 @@ from functools import lru_cache from pathlib import Path from typing import Any, Dict, Optional +import torch from lhotse import CutSet, Fbank, FbankConfig, load_manifest from lhotse.dataset import ( BucketingSampler, @@ -34,6 +35,7 @@ from lhotse.dataset import ( SpecAugment, ) from lhotse.dataset.input_strategies import OnTheFlyFeatures +from lhotse.utils import fix_random_seed from torch.utils.data import DataLoader from icefall.utils import str2bool @@ -301,12 +303,20 @@ class LibriSpeechAsrDataModule: logging.info("Loading sampler state dict") train_sampler.load_state_dict(sampler_state_dict) + # 'seed' is derived from the current random state, which will have + # previously been set in the main process. + seed = torch.randint(0, 100000, ()).item() + + def worker_init_fn(worker_id: int): + fix_random_seed(seed + worker_id) + train_dl = DataLoader( train, sampler=train_sampler, batch_size=None, num_workers=self.args.num_workers, persistent_workers=False, + worker_init_fn=worker_init_fn, ) return train_dl diff --git a/egs/librispeech/ASR/transducer/train.py b/egs/librispeech/ASR/transducer/train.py index a6ce79520..cbd9259e0 100755 --- a/egs/librispeech/ASR/transducer/train.py +++ b/egs/librispeech/ASR/transducer/train.py @@ -34,6 +34,7 @@ export CUDA_VISIBLE_DEVICES="0,1,2,3" import argparse import logging +import warnings from pathlib import Path from shutil import copyfile from typing import Optional, Tuple @@ -393,7 +394,11 @@ def compute_loss( assert loss.requires_grad == is_training info = MetricsTracker() - info["frames"] = (feature_lens // params.subsampling_factor).sum().item() + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + info["frames"] = ( + (feature_lens // params.subsampling_factor).sum().item() + ) # Note: We use reduction=sum while computing the loss. info["loss"] = loss.detach().cpu().item() diff --git a/egs/librispeech/ASR/transducer_lstm/train.py b/egs/librispeech/ASR/transducer_lstm/train.py index 9f06ed512..eef4d3430 100755 --- a/egs/librispeech/ASR/transducer_lstm/train.py +++ b/egs/librispeech/ASR/transducer_lstm/train.py @@ -35,6 +35,7 @@ export CUDA_VISIBLE_DEVICES="0,1,2" import argparse import logging +import warnings from pathlib import Path from shutil import copyfile from typing import Optional, Tuple @@ -397,7 +398,11 @@ def compute_loss( assert loss.requires_grad == is_training info = MetricsTracker() - info["frames"] = (feature_lens // params.subsampling_factor).sum().item() + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + info["frames"] = ( + (feature_lens // params.subsampling_factor).sum().item() + ) # Note: We use reduction=sum while computing the loss. info["loss"] = loss.detach().cpu().item() diff --git a/egs/librispeech/ASR/transducer_stateless/conformer.py b/egs/librispeech/ASR/transducer_stateless/conformer.py index fc838f75b..488c82386 100644 --- a/egs/librispeech/ASR/transducer_stateless/conformer.py +++ b/egs/librispeech/ASR/transducer_stateless/conformer.py @@ -109,8 +109,11 @@ class Conformer(Transformer): x, pos_emb = self.encoder_pos(x) x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C) - # Caution: We assume the subsampling factor is 4! - lengths = ((x_lens - 1) // 2 - 1) // 2 + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + # Caution: We assume the subsampling factor is 4! + lengths = ((x_lens - 1) // 2 - 1) // 2 + assert x.size(0) == lengths.max().item() mask = make_pad_mask(lengths) diff --git a/egs/librispeech/ASR/transducer_stateless/train.py b/egs/librispeech/ASR/transducer_stateless/train.py index 2cc6480d5..d6827c17c 100755 --- a/egs/librispeech/ASR/transducer_stateless/train.py +++ b/egs/librispeech/ASR/transducer_stateless/train.py @@ -34,6 +34,7 @@ export CUDA_VISIBLE_DEVICES="0,1,2,3" import argparse import logging +import warnings from pathlib import Path from shutil import copyfile from typing import Optional, Tuple @@ -419,7 +420,11 @@ def compute_loss( assert loss.requires_grad == is_training info = MetricsTracker() - info["frames"] = (feature_lens // params.subsampling_factor).sum().item() + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + info["frames"] = ( + (feature_lens // params.subsampling_factor).sum().item() + ) # Note: We use reduction=sum while computing the loss. info["loss"] = loss.detach().cpu().item() diff --git a/egs/librispeech/ASR/transducer_stateless_multi_datasets/asr_datamodule.py b/egs/librispeech/ASR/transducer_stateless_multi_datasets/asr_datamodule.py index 669ad1d1b..2ce8d8752 100644 --- a/egs/librispeech/ASR/transducer_stateless_multi_datasets/asr_datamodule.py +++ b/egs/librispeech/ASR/transducer_stateless_multi_datasets/asr_datamodule.py @@ -22,6 +22,7 @@ import logging from pathlib import Path from typing import Optional +import torch from lhotse import CutSet, Fbank, FbankConfig from lhotse.dataset import ( BucketingSampler, @@ -34,6 +35,7 @@ from lhotse.dataset.input_strategies import ( OnTheFlyFeatures, PrecomputedFeatures, ) +from lhotse.utils import fix_random_seed from torch.utils.data import DataLoader from icefall.utils import str2bool @@ -253,12 +255,21 @@ class AsrDataModule: ) logging.info("About to create train dataloader") + + # 'seed' is derived from the current random state, which will have + # previously been set in the main process. + seed = torch.randint(0, 100000, ()).item() + + def worker_init_fn(worker_id: int): + fix_random_seed(seed + worker_id) + train_dl = DataLoader( train, sampler=train_sampler, batch_size=None, num_workers=self.args.num_workers, persistent_workers=False, + worker_init_fn=worker_init_fn, ) return train_dl diff --git a/egs/librispeech/ASR/transducer_stateless_multi_datasets/train.py b/egs/librispeech/ASR/transducer_stateless_multi_datasets/train.py index 105f82417..5572d3f4c 100755 --- a/egs/librispeech/ASR/transducer_stateless_multi_datasets/train.py +++ b/egs/librispeech/ASR/transducer_stateless_multi_datasets/train.py @@ -58,6 +58,7 @@ export CUDA_VISIBLE_DEVICES="0,1,2,3" import argparse import logging import random +import warnings from pathlib import Path from shutil import copyfile from typing import Optional, Tuple @@ -466,7 +467,11 @@ def compute_loss( assert loss.requires_grad == is_training info = MetricsTracker() - info["frames"] = (feature_lens // params.subsampling_factor).sum().item() + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + info["frames"] = ( + (feature_lens // params.subsampling_factor).sum().item() + ) # Note: We use reduction=sum while computing the loss. info["loss"] = loss.detach().cpu().item() From e7493ede9069c725e083235b4bfa50bc81e5cf45 Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Thu, 31 Mar 2022 20:32:00 +0800 Subject: [PATCH 12/16] Don't use a lambda for dataloader's worker_init_fn. (#284) * Don't use a lambda for dataloader's worker_init_fn. --- egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py | 12 +++++++++--- .../asr_datamodule.py | 12 +++++++++--- 2 files changed, 18 insertions(+), 6 deletions(-) diff --git a/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py b/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py index 8790b21e7..8dd1459ca 100644 --- a/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py +++ b/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py @@ -41,6 +41,14 @@ from torch.utils.data import DataLoader from icefall.utils import str2bool +class _SeedWorkers: + def __init__(self, seed: int): + self.seed = seed + + def __call__(self, worker_id: int): + fix_random_seed(self.seed + worker_id) + + class LibriSpeechAsrDataModule: """ DataModule for k2 ASR experiments. @@ -306,9 +314,7 @@ class LibriSpeechAsrDataModule: # 'seed' is derived from the current random state, which will have # previously been set in the main process. seed = torch.randint(0, 100000, ()).item() - - def worker_init_fn(worker_id: int): - fix_random_seed(seed + worker_id) + worker_init_fn = _SeedWorkers(seed) train_dl = DataLoader( train, diff --git a/egs/librispeech/ASR/transducer_stateless_multi_datasets/asr_datamodule.py b/egs/librispeech/ASR/transducer_stateless_multi_datasets/asr_datamodule.py index 2ce8d8752..c6cf739fb 100644 --- a/egs/librispeech/ASR/transducer_stateless_multi_datasets/asr_datamodule.py +++ b/egs/librispeech/ASR/transducer_stateless_multi_datasets/asr_datamodule.py @@ -41,6 +41,14 @@ from torch.utils.data import DataLoader from icefall.utils import str2bool +class _SeedWorkers: + def __init__(self, seed: int): + self.seed = seed + + def __call__(self, worker_id: int): + fix_random_seed(self.seed + worker_id) + + class AsrDataModule: def __init__(self, args: argparse.Namespace): self.args = args @@ -259,9 +267,7 @@ class AsrDataModule: # 'seed' is derived from the current random state, which will have # previously been set in the main process. seed = torch.randint(0, 100000, ()).item() - - def worker_init_fn(worker_id: int): - fix_random_seed(seed + worker_id) + worker_init_fn = _SeedWorkers(seed) train_dl = DataLoader( train, From 0b6a2213c389b2663d1adccb690a3df1f1b1f5a9 Mon Sep 17 00:00:00 2001 From: Zengwei Yao Date: Sat, 2 Apr 2022 15:01:45 +0800 Subject: [PATCH 13/16] Modify icefall/__init__.py. (#287) * Modify icefall/__init__.py to import common functions defined in icefall/utils.py. * Modify icefall/__init__.py and .flake8. --- .flake8 | 3 ++- icefall/__init__.py | 24 ++++++++++++++++++++++++ 2 files changed, 26 insertions(+), 1 deletion(-) diff --git a/.flake8 b/.flake8 index 229cf1d6c..dd9239b2d 100644 --- a/.flake8 +++ b/.flake8 @@ -13,4 +13,5 @@ per-file-ignores = exclude = .git, **/data/**, - icefall/shared/make_kn_lm.py + icefall/shared/make_kn_lm.py, + icefall/__init__.py diff --git a/icefall/__init__.py b/icefall/__init__.py index e69de29bb..983539d6f 100644 --- a/icefall/__init__.py +++ b/icefall/__init__.py @@ -0,0 +1,24 @@ +from .utils import ( + AttributeDict, + MetricsTracker, + add_eos, + add_sos, + concat, + encode_supervisions, + get_alignments, + get_executor, + get_texts, + l1_norm, + l2_norm, + linf_norm, + load_alignments, + make_pad_mask, + measure_gradient_norms, + measure_weight_norms, + optim_step_and_measure_param_change, + save_alignments, + setup_logger, + store_transcripts, + str2bool, + write_error_stats, +) From 87cf9231ea73631f1e4453400b3be06d45bcebf5 Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Sun, 3 Apr 2022 13:02:08 +0800 Subject: [PATCH 14/16] Support specifying iteration number of checkpoints for decoding. (#289) --- .../ASR/pruned_transducer_stateless/decode.py | 55 +++++++++++++------ icefall/checkpoint.py | 43 +++++++++++++-- 2 files changed, 76 insertions(+), 22 deletions(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless/decode.py index 8e924bf96..49b1308b0 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless/decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless/decode.py @@ -98,27 +98,28 @@ def get_parser(): "--epoch", type=int, default=28, - help="It specifies the checkpoint to use for decoding." - "Note: Epoch counts from 0.", + help="""It specifies the checkpoint to use for decoding. + Note: Epoch counts from 0. + You can specify --avg to use more checkpoints for model averaging.""", ) + + parser.add_argument( + "--iter", + type=int, + default=0, + help="""If positive, --epoch is ignored and it + will use the checkpoint exp_dir/checkpoint-iter.pt. + You can specify --avg to use more checkpoints for model averaging. + """, + ) + parser.add_argument( "--avg", type=int, default=15, help="Number of checkpoints to average. Automatically select " "consecutive checkpoints before the checkpoint specified by " - "'--epoch'. ", - ) - - parser.add_argument( - "--avg-last-n", - type=int, - default=0, - help="""If positive, --epoch and --avg are ignored and it - will use the last n checkpoints exp_dir/checkpoint-xxx.pt - where xxx is the number of processed batches while - saving that checkpoint. - """, + "'--epoch' and '--iter'", ) parser.add_argument( @@ -453,13 +454,19 @@ def main(): ) params.res_dir = params.exp_dir / params.decoding_method - params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" + if params.iter > 0: + params.suffix = f"iter-{params.iter}-avg-{params.avg}" + else: + params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" + if "fast_beam_search" in params.decoding_method: params.suffix += f"-beam-{params.beam}" params.suffix += f"-max-contexts-{params.max_contexts}" params.suffix += f"-max-states-{params.max_states}" elif "beam_search" in params.decoding_method: - params.suffix += f"-beam-{params.beam_size}" + params.suffix += ( + f"-{params.decoding_method}-beam-size-{params.beam_size}" + ) else: params.suffix += f"-context-{params.context_size}" params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" @@ -485,8 +492,20 @@ def main(): logging.info("About to create model") model = get_transducer_model(params) - if params.avg_last_n > 0: - filenames = find_checkpoints(params.exp_dir)[: params.avg_last_n] + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) logging.info(f"averaging {filenames}") model.to(device) model.load_state_dict(average_checkpoints(filenames, device=device)) diff --git a/icefall/checkpoint.py b/icefall/checkpoint.py index 251456c95..1ef05d964 100644 --- a/icefall/checkpoint.py +++ b/icefall/checkpoint.py @@ -216,27 +216,62 @@ def save_checkpoint_with_global_batch_idx( ) -def find_checkpoints(out_dir: Path) -> List[str]: +def find_checkpoints(out_dir: Path, iteration: int = 0) -> List[str]: """Find all available checkpoints in a directory. The checkpoint filenames have the form: `checkpoint-xxx.pt` where xxx is a numerical value. + Assume you have the following checkpoints in the folder `foo`: + + - checkpoint-1.pt + - checkpoint-20.pt + - checkpoint-300.pt + - checkpoint-4000.pt + + Case 1 (Return all checkpoints):: + + find_checkpoints(out_dir='foo') + + Case 2 (Return checkpoints newer than checkpoint-20.pt, i.e., + checkpoint-4000.pt, checkpoint-300.pt, and checkpoint-20.pt) + + find_checkpoints(out_dir='foo', iteration=20) + + Case 3 (Return checkpoints older than checkpoint-20.pt, i.e., + checkpoint-20.pt, checkpoint-1.pt):: + + find_checkpoints(out_dir='foo', iteration=-20) + Args: out_dir: The directory where to search for checkpoints. + iteration: + If it is 0, return all available checkpoints. + If it is positive, return the checkpoints whose iteration number is + greater than or equal to `iteration`. + If it is negative, return the checkpoints whose iteration number is + less than or equal to `-iteration`. Returns: Return a list of checkpoint filenames, sorted in descending order by the numerical value in the filename. """ checkpoints = list(glob.glob(f"{out_dir}/checkpoint-[0-9]*.pt")) pattern = re.compile(r"checkpoint-([0-9]+).pt") - idx_checkpoints = [ + iter_checkpoints = [ (int(pattern.search(c).group(1)), c) for c in checkpoints ] + # iter_checkpoints is a list of tuples. Each tuple contains + # two elements: (iteration_number, checkpoint-iteration_number.pt) + + iter_checkpoints = sorted( + iter_checkpoints, reverse=True, key=lambda x: x[0] + ) + if iteration >= 0: + ans = [ic[1] for ic in iter_checkpoints if ic[0] >= iteration] + else: + ans = [ic[1] for ic in iter_checkpoints if ic[0] <= -iteration] - idx_checkpoints = sorted(idx_checkpoints, reverse=True, key=lambda x: x[0]) - ans = [ic[1] for ic in idx_checkpoints] return ans From cb3ba16f2bf733e63c8f798b27121da93e58738b Mon Sep 17 00:00:00 2001 From: Wei Kang Date: Tue, 5 Apr 2022 10:22:49 +0800 Subject: [PATCH 15/16] Fix aishell prepare.sh when using pre-download data (#291) --- egs/aishell/ASR/prepare.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/aishell/ASR/prepare.sh b/egs/aishell/ASR/prepare.sh index 68f5c54d3..26324b0af 100755 --- a/egs/aishell/ASR/prepare.sh +++ b/egs/aishell/ASR/prepare.sh @@ -70,7 +70,7 @@ if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then # |-- lexicon.txt # `-- speaker.info - if [ ! -d $dl_dir/aishell/data_aishell/wav ]; then + if [ ! -d $dl_dir/aishell/data_aishell/wav/train ]; then lhotse download aishell $dl_dir fi From ceeb95bcb8c12e5047be7f12440296bd9532c0e3 Mon Sep 17 00:00:00 2001 From: Zengwei Yao Date: Wed, 6 Apr 2022 11:55:29 +0800 Subject: [PATCH 16/16] update icefall/__init__.py to import more common functions. (#294) --- icefall/__init__.py | 31 +++++++++++++++++++++++++++++++ pyproject.toml | 1 + 2 files changed, 32 insertions(+) diff --git a/icefall/__init__.py b/icefall/__init__.py index 983539d6f..f466d6a62 100644 --- a/icefall/__init__.py +++ b/icefall/__init__.py @@ -1,3 +1,34 @@ +from .checkpoint import ( + average_checkpoints, + find_checkpoints, + load_checkpoint, + remove_checkpoints, + save_checkpoint, + save_checkpoint_with_global_batch_idx, +) + +from .decode import ( + get_lattice, + nbest_decoding, + nbest_oracle, + one_best_decoding, + rescore_with_attention_decoder, + rescore_with_n_best_list, + rescore_with_whole_lattice, +) + +from .dist import ( + cleanup_dist, + setup_dist, +) + +from .env import ( + get_env_info, + get_git_branch_name, + get_git_date, + get_git_sha1, +) + from .utils import ( AttributeDict, MetricsTracker, diff --git a/pyproject.toml b/pyproject.toml index 01ff869db..ec5623f90 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,5 +1,6 @@ [tool.isort] profile = "black" +skip = ["icefall/__init__.py"] [tool.black] line-length = 80